{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "file_extension": ".py",
    "kernelspec": {
      "display_name": "Python 3.7.7 64-bit ('datasets': conda)",
      "language": "python",
      "name": "python37764bitdatasetscondae5d8ff60608e4c5c953d6bb643d8ebc5"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.7"
    },
    "mimetype": "text/x-python",
    "name": "python",
    "npconvert_exporter": "python",
    "pygments_lexer": "ipython3",
    "version": 3,
    "colab": {
      "name": "HuggingFace nlp library - Overview",
      "provenance": [],
      "include_colab_link": true
    },
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "227636ea1d9b40b5ac28cef9cb28b4c8": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_a5bdbfc40d4e445c8c6e7e1a31a79c39",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_edbd3cac20ce48b28004e106d6364d0b",
              "IPY_MODEL_16f9c946d61943b7b4d4359c89ce31c5"
            ]
          }
        },
        "a5bdbfc40d4e445c8c6e7e1a31a79c39": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "edbd3cac20ce48b28004e106d6364d0b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_23644534ba354f3389fb53e3a9d07767",
            "_dom_classes": [],
            "description": "Downloading: 100%",
            "_model_name": "FloatProgressModel",
            "bar_style": "success",
            "max": 4997,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 4997,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_6c8d292d051045848bec0462cd0b85ca"
          }
        },
        "16f9c946d61943b7b4d4359c89ce31c5": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_36d0bbf33e4e4ca4a97966f4748d348c",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 5.00k/5.00k [00:01&lt;00:00, 3.04kB/s]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_0afa8cbe1f9641a6a0f5ce521625ac22"
          }
        },
        "23644534ba354f3389fb53e3a9d07767": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "initial",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "6c8d292d051045848bec0462cd0b85ca": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "36d0bbf33e4e4ca4a97966f4748d348c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "0afa8cbe1f9641a6a0f5ce521625ac22": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "9a03f10dd7ed4f54b9b093d6e99b510b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_c03eb6bd0c0e440dabbebad1b37d5b37",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_f55eae0481714b80aa7a30145486d4c8",
              "IPY_MODEL_1eae9b7e6ab54bf3b9d2c546c9fa560c"
            ]
          }
        },
        "c03eb6bd0c0e440dabbebad1b37d5b37": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "f55eae0481714b80aa7a30145486d4c8": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_ef4f1ad6f07f4a3b9100d05c86685146",
            "_dom_classes": [],
            "description": "Downloading: 100%",
            "_model_name": "FloatProgressModel",
            "bar_style": "success",
            "max": 2240,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 2240,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_73ff00b44b264bce83e0a534a7697ee9"
          }
        },
        "1eae9b7e6ab54bf3b9d2c546c9fa560c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_ef600a6844fe4557b27c857520733cdd",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 2.24k/2.24k [00:05&lt;00:00, 396B/s]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_224a993375ca44e2bb9ee1c4c559a28f"
          }
        },
        "ef4f1ad6f07f4a3b9100d05c86685146": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "initial",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "73ff00b44b264bce83e0a534a7697ee9": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "ef600a6844fe4557b27c857520733cdd": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "224a993375ca44e2bb9ee1c4c559a28f": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "08505eb8a7c649ff811e0f10ea2a27b1": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_0cdc2e5f4c374239b3168b1261c60272",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_a3d61fc47d1a4fd6b805c9726eae3f3b",
              "IPY_MODEL_d020b1411bc1493eac1eed153dfc7c9c"
            ]
          }
        },
        "0cdc2e5f4c374239b3168b1261c60272": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "a3d61fc47d1a4fd6b805c9726eae3f3b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_f507a7a4ffde4238847374b71070b33d",
            "_dom_classes": [],
            "description": "Downloading: ",
            "_model_name": "FloatProgressModel",
            "bar_style": "success",
            "max": 8116577,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 8116577,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_dcf172324f264f2eac52b4fe87f46c01"
          }
        },
        "d020b1411bc1493eac1eed153dfc7c9c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_ea29d58feeea48cabee5cc4beed63da8",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 30.3M/? [00:01&lt;00:00, 19.6MB/s]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_d388ffbc7d984f5f87f7e468ea235f40"
          }
        },
        "f507a7a4ffde4238847374b71070b33d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "initial",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "dcf172324f264f2eac52b4fe87f46c01": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "ea29d58feeea48cabee5cc4beed63da8": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "d388ffbc7d984f5f87f7e468ea235f40": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "6fcb4fa7272345bf8eb98b3732ffaed7": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_b914d448c9e74a259e28f036b10b1103",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_90329e7c01064e5dacfba73a54c069c9",
              "IPY_MODEL_32a4dc10fb9847e9b9d1e3a23eb586f8"
            ]
          }
        },
        "b914d448c9e74a259e28f036b10b1103": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "90329e7c01064e5dacfba73a54c069c9": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_1cce33d331af49d6bd783e12601ec28c",
            "_dom_classes": [],
            "description": "Downloading: ",
            "_model_name": "FloatProgressModel",
            "bar_style": "success",
            "max": 1054280,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 1054280,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_97630224661c4b0db7bb9cf521686f24"
          }
        },
        "32a4dc10fb9847e9b9d1e3a23eb586f8": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_c48a2e3e9419495f8319fb560fafd30b",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 4.85M/? [00:00&lt;00:00, 13.4MB/s]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_8df2f18abed046d48d57a14421f4734f"
          }
        },
        "1cce33d331af49d6bd783e12601ec28c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "initial",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "97630224661c4b0db7bb9cf521686f24": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "c48a2e3e9419495f8319fb560fafd30b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "8df2f18abed046d48d57a14421f4734f": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "fb6cc99c79dd429fa1fcf560a5652433": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_5b49957191464f598e839eefbe5a7f9a",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_f5c3a667c46148cc90146515c67d8521",
              "IPY_MODEL_0f1953935cb047c2a124b2c2942ae9c8"
            ]
          }
        },
        "5b49957191464f598e839eefbe5a7f9a": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "f5c3a667c46148cc90146515c67d8521": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_70bc827e4f7c454d9281a46ad6968bd8",
            "_dom_classes": [],
            "description": "",
            "_model_name": "FloatProgressModel",
            "bar_style": "info",
            "max": 1,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 1,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_9d33f81817a748e8b51038a582ce3a54"
          }
        },
        "0f1953935cb047c2a124b2c2942ae9c8": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_f7aa7ed00ccb45a984b6041c28f8e332",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 87599/0 [00:03&lt;00:00, 5305.14 examples/s]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_40dd2802be014e7d9d0d83fec33362e3"
          }
        },
        "70bc827e4f7c454d9281a46ad6968bd8": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "initial",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "9d33f81817a748e8b51038a582ce3a54": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "f7aa7ed00ccb45a984b6041c28f8e332": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "40dd2802be014e7d9d0d83fec33362e3": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "c1371bab2a1a4f999b74dd2dd4680996": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_48ec76136e134c00a76049a275c09f9e",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_8fbdaba38e3643e094cfe9fc8a8ad0e5",
              "IPY_MODEL_b9b9b2bdd015443ca022e6f89d06b261"
            ]
          }
        },
        "48ec76136e134c00a76049a275c09f9e": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "8fbdaba38e3643e094cfe9fc8a8ad0e5": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_de6ae2f354e640008c8bacedda3239b1",
            "_dom_classes": [],
            "description": "",
            "_model_name": "FloatProgressModel",
            "bar_style": "info",
            "max": 1,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 1,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_6a5810eeb8684e8899c5c1d42f374f5f"
          }
        },
        "b9b9b2bdd015443ca022e6f89d06b261": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_7fe74a50fdf94259a26fe3f71390958f",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 10570/0 [00:00&lt;00:00, 17.06 examples/s]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_c913972b39594a7da6d4b7adfc0251f1"
          }
        },
        "de6ae2f354e640008c8bacedda3239b1": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "initial",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "6a5810eeb8684e8899c5c1d42f374f5f": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "7fe74a50fdf94259a26fe3f71390958f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "c913972b39594a7da6d4b7adfc0251f1": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "b38b9470cc694dd2aa6d36982fecc94f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_6ad42f91046e4ba2a0fe329946bf57b6",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_931883ddbf7345dfbfcf091c53c4861a",
              "IPY_MODEL_db9de67daefd4065b390637c0294e2b6"
            ]
          }
        },
        "6ad42f91046e4ba2a0fe329946bf57b6": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "931883ddbf7345dfbfcf091c53c4861a": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_5379dd7943304dfdb1e96162e1f85832",
            "_dom_classes": [],
            "description": "Downloading: 100%",
            "_model_name": "FloatProgressModel",
            "bar_style": "success",
            "max": 213450,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 213450,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_f557d0df4c664ec0a2014056ad70447a"
          }
        },
        "db9de67daefd4065b390637c0294e2b6": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_8b69d12f2711475e945c6a80b7516d5e",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 213k/213k [00:00&lt;00:00, 333kB/s]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_36c77532c74b4f9388d8b182f9628f2b"
          }
        },
        "5379dd7943304dfdb1e96162e1f85832": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "initial",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "f557d0df4c664ec0a2014056ad70447a": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "8b69d12f2711475e945c6a80b7516d5e": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "36c77532c74b4f9388d8b182f9628f2b": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "454a7d19cdb04b7e93efec4e2b374345": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_adb39c17c61947ce9fb241311e929135",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_557a58848c2a498882429ad3b3efe29f",
              "IPY_MODEL_ea2a6e311e164c00b7e9801a58aa44c4"
            ]
          }
        },
        "adb39c17c61947ce9fb241311e929135": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "557a58848c2a498882429ad3b3efe29f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_72b94b71b4214c53891769dc1bd1d15d",
            "_dom_classes": [],
            "description": "Downloading: 100%",
            "_model_name": "FloatProgressModel",
            "bar_style": "success",
            "max": 411,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 411,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_cf063e09a69243d3af2f6a97f4f6417e"
          }
        },
        "ea2a6e311e164c00b7e9801a58aa44c4": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_fafb257a0fe24cc9a8ea3e1a0b87e42b",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 411/411 [00:00&lt;00:00, 881B/s]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_0e9a15d300d547dbae5f6fc6a1df1782"
          }
        },
        "72b94b71b4214c53891769dc1bd1d15d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "initial",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "cf063e09a69243d3af2f6a97f4f6417e": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "fafb257a0fe24cc9a8ea3e1a0b87e42b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "0e9a15d300d547dbae5f6fc6a1df1782": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "049655d977ec42ddaa2079422cc0a635": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_4807314ef3b342a1baf6dd2e0cd16cd5",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_508c6262c25a41de89be1499c973a126",
              "IPY_MODEL_a45ffab324f94d4882863e56b3d43678"
            ]
          }
        },
        "4807314ef3b342a1baf6dd2e0cd16cd5": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "508c6262c25a41de89be1499c973a126": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_7d3af3d2939c49e2b113d92e8573003f",
            "_dom_classes": [],
            "description": "Downloading: 100%",
            "_model_name": "FloatProgressModel",
            "bar_style": "success",
            "max": 263273408,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 263273408,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_5627b08bdb1f4a9fa2d58c63481e1129"
          }
        },
        "a45ffab324f94d4882863e56b3d43678": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_259b90b57d064c938b47d61c279ae237",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 263M/263M [00:12&lt;00:00, 20.5MB/s]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_ef365e73c93340a983ace25acfa41c17"
          }
        },
        "7d3af3d2939c49e2b113d92e8573003f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "initial",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "5627b08bdb1f4a9fa2d58c63481e1129": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "259b90b57d064c938b47d61c279ae237": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "ef365e73c93340a983ace25acfa41c17": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        }
      }
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/huggingface/nlp/blob/master/notebooks/Overview.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zNp6kK7OvSUg",
        "colab_type": "text"
      },
      "source": [
        "# HuggingFace `nlp` library - Quick overview\n",
        "\n",
        "Models come and go (linear models, LSTM, Transformers, ...) but two core elements have consistently been the beating heart of Natural Language Processing: Datasets & Metrics\n",
        "\n",
        "`nlp` is a lightweight and extensible library to easily share and load dataset and evaluation metrics, already providing access to ~100 datasets and ~10 evaluation metrics.\n",
        "\n",
        "The library has several interesting features (beside easy access to datasets/metrics):\n",
        "\n",
        "- Build-in interoperability with PyTorch, Tensorflow 2, Pandas and Numpy\n",
        "- Small and fast library with a transparent and pythonic API\n",
        "- Strive on large datasets: nlp naturally frees you from RAM memory limits, all datasets are memory-mapped on drive by default.\n",
        "- Smart caching with an intelligent `tf.data`-like cache: never wait for your data to process several times\n",
        "\n",
        "`nlp` originated from a fork of the awesome Tensorflow-Datasets and the HuggingFace team want to deeply thank the team behind this amazing library and user API. We have tried to keep a layer of compatibility with `tfds` and a conversion can provide conversion from one format to the other."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dzk9aEtIvSUh",
        "colab_type": "text"
      },
      "source": [
        "# Main datasets API\n",
        "\n",
        "This notebook is a quick dive in the main user API for loading datasets in `nlp`"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "my95uHbLyjwR",
        "colab_type": "code",
        "outputId": "b8d02789-6168-4b28-ac8b-b22975e7c5fa",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 340
        }
      },
      "source": [
        "!pip install nlp==0.1.0"
      ],
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Collecting nlp==0.1.0\n",
            "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/21/17/0a408ac3403c71c978e7906146c69ed00b18712a4548e34d0ffb567c34cc/nlp-0.1.0-py3-none-any.whl (87kB)\n",
            "\u001b[K     |████████████████████████████████| 92kB 2.5MB/s \n",
            "\u001b[?25hRequirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.6/dist-packages (from nlp==0.1.0) (2.23.0)\n",
            "Requirement already satisfied: pyarrow>=0.16.0 in /usr/local/lib/python3.6/dist-packages (from nlp==0.1.0) (0.17.0)\n",
            "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.6/dist-packages (from nlp==0.1.0) (4.41.1)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.6/dist-packages (from nlp==0.1.0) (3.0.12)\n",
            "Requirement already satisfied: dataclasses; python_version < \"3.7\" in /usr/local/lib/python3.6/dist-packages (from nlp==0.1.0) (0.7)\n",
            "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from nlp==0.1.0) (1.18.4)\n",
            "Requirement already satisfied: dill in /usr/local/lib/python3.6/dist-packages (from nlp==0.1.0) (0.3.1.1)\n",
            "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests>=2.19.0->nlp==0.1.0) (3.0.4)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests>=2.19.0->nlp==0.1.0) (2020.4.5.1)\n",
            "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests>=2.19.0->nlp==0.1.0) (1.24.3)\n",
            "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests>=2.19.0->nlp==0.1.0) (2.9)\n",
            "Installing collected packages: nlp\n",
            "  Found existing installation: nlp 0.0.3\n",
            "    Uninstalling nlp-0.0.3:\n",
            "      Successfully uninstalled nlp-0.0.3\n",
            "Successfully installed nlp-0.1.0\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "hJHyEmievSUh",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import logging\n",
        "logging.basicConfig(level=logging.INFO)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "PVjXLiYxvSUl",
        "colab_type": "code",
        "outputId": "c9bddcbc-dbaa-4df5-f2c8-eb2f0420c5f4",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 51
        }
      },
      "source": [
        "# Let's import the library\n",
        "import nlp"
      ],
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "INFO:nlp.utils.file_utils:PyTorch version 1.5.0+cu101 available.\n",
            "INFO:nlp.utils.file_utils:TensorFlow version 2.2.0 available.\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TNloBBx-vSUo",
        "colab_type": "text"
      },
      "source": [
        "## Listing the currently available datasets and metrics"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "d3RJisGLvSUp",
        "colab_type": "code",
        "outputId": "9359d476-d0e7-410d-a2e5-7a988e5974dc",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        }
      },
      "source": [
        "# Currently available datasets and metrics\n",
        "datasets = nlp.list_datasets()\n",
        "metrics = nlp.list_metrics()\n",
        "\n",
        "print(f\"🤩 Currently {len(datasets)} datasets are available on HuggingFace AWS bucket: \\n\" \n",
        "      + '\\n'.join(dataset.id for dataset in datasets) + '\\n')\n",
        "print(f\"🤩 Currently {len(metrics)} metrics are available on HuggingFace AWS bucket: \\n\" \n",
        "      + '\\n'.join(metric.id for metric in metrics))"
      ],
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "🤩 Currently 93 datasets are available on HuggingFace AWS bucket: \n",
            "aeslc\n",
            "ai2_arc\n",
            "anli\n",
            "billsum\n",
            "blimp\n",
            "blog_authorship_corpus\n",
            "boolq\n",
            "break_data\n",
            "cfq\n",
            "civil_comments\n",
            "cmrc2018\n",
            "cnn_dailymail\n",
            "coarse_discourse\n",
            "com_qa\n",
            "commonsense_qa\n",
            "coqa\n",
            "cornell_movie_dialog\n",
            "cos_e\n",
            "cosmos_qa\n",
            "crime_and_punish\n",
            "csv\n",
            "definite_pronoun_resolution\n",
            "discofuse\n",
            "drop\n",
            "empathetic_dialogues\n",
            "eraser_multi_rc\n",
            "esnli\n",
            "event2Mind\n",
            "flores\n",
            "fquad\n",
            "gap\n",
            "gigaword\n",
            "glue\n",
            "hansards\n",
            "hellaswag\n",
            "imdb\n",
            "jeopardy\n",
            "librispeech_lm\n",
            "lm1b\n",
            "math_dataset\n",
            "mlqa\n",
            "movie_rationales\n",
            "multi_news\n",
            "multi_nli\n",
            "multi_nli_mismatch\n",
            "newsroom\n",
            "openbookqa\n",
            "opinosis\n",
            "para_crawl\n",
            "qa4mre\n",
            "qangaroo\n",
            "qasc\n",
            "quarel\n",
            "quartz\n",
            "quoref\n",
            "race\n",
            "reclor\n",
            "reddit\n",
            "reddit_tifu\n",
            "scan\n",
            "scicite\n",
            "scientific_papers\n",
            "scifact\n",
            "sciq\n",
            "scitail\n",
            "sentiment140\n",
            "snli\n",
            "squad\n",
            "squad_it\n",
            "squad_v1_pt\n",
            "squad_v2\n",
            "super_glue\n",
            "ted_hrlr\n",
            "ted_multi\n",
            "tiny_shakespeare\n",
            "trivia_qa\n",
            "tydiqa\n",
            "webis/tl_dr\n",
            "wiki40b\n",
            "wiki_qa\n",
            "wiki_split\n",
            "wikihow\n",
            "wikipedia\n",
            "wikitext\n",
            "winogrande\n",
            "wiqa\n",
            "x_stance\n",
            "xcopa\n",
            "xnli\n",
            "xquad\n",
            "xsum\n",
            "xtreme\n",
            "yelp_polarity\n",
            "\n",
            "🤩 Currently 10 metrics are available on HuggingFace AWS bucket: \n",
            "bleu\n",
            "coval\n",
            "gleu\n",
            "glue\n",
            "rouge\n",
            "sacrebleu\n",
            "seqeval\n",
            "squad\n",
            "squad_v2\n",
            "xnli\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7T5AG3BxvSUr",
        "colab_type": "code",
        "outputId": "9ccd5444-3aeb-489c-bd66-458fc4e34819",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 360
        }
      },
      "source": [
        "# You can read a few attributes of the datasets before loading them (they are python dataclasses)\n",
        "from dataclasses import asdict\n",
        "\n",
        "for key, value in asdict(datasets[6]).items():\n",
        "    print('👉 ' + key + ': ' + str(value))"
      ],
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "👉 id: boolq\n",
            "👉 key: nlp/datasets/boolq/boolq.py\n",
            "👉 lastModified: 2020-05-14T14:57:19.000Z\n",
            "👉 description: \\\n",
            "BoolQ is a question answering dataset for yes/no questions containing 15942 examples. These questions are naturally \n",
            "occurring ---they are generated in unprompted and unconstrained settings. \n",
            "Each example is a triplet of (question, passage, answer), with the title of the page as optional additional context. \n",
            "The text-pair classification setup is similar to existing natural language inference tasks.\n",
            "👉 citation: \\\n",
            "@inproceedings{clark2019boolq,\n",
            "  title =     {BoolQ: Exploring the Surprising Difficulty of Natural Yes/No Questions},\n",
            "  author =    {Clark, Christopher and Lee, Kenton and Chang, Ming-Wei, and Kwiatkowski, Tom and Collins, Michael, and Toutanova, Kristina},\n",
            "  booktitle = {NAACL},\n",
            "  year =      {2019},\n",
            "}\n",
            "👉 size: 3444\n",
            "👉 etag: \"bf50ccf7a03c3167addb1032f83f883a\"\n",
            "👉 siblings: [{'key': 'nlp/datasets/boolq/boolq.py', 'etag': '\"bf50ccf7a03c3167addb1032f83f883a\"', 'lastModified': '2020-05-14T14:57:19.000Z', 'size': 3444, 'rfilename': 'boolq.py'}, {'key': 'nlp/datasets/boolq/dataset_infos.json', 'etag': '\"3cc498e42dc0b0298e7be42e9338ab77\"', 'lastModified': '2020-05-14T14:57:19.000Z', 'size': 1834, 'rfilename': 'dataset_infos.json'}, {'key': 'nlp/datasets/boolq/dummy/0.1.0/dummy_data.zip', 'etag': '\"880a58d1776b049ef92bbe9586bf6744\"', 'lastModified': '2020-05-14T14:57:20.000Z', 'size': 2368, 'rfilename': 'dummy/0.1.0/dummy_data.zip'}]\n",
            "👉 author: None\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9uqSkkSovSUt",
        "colab_type": "text"
      },
      "source": [
        "## An example with SQuAD"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "aOXl6afcvSUu",
        "colab_type": "code",
        "outputId": "64812b74-9b98-4907-ada3-9f678e896d86",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 947,
          "referenced_widgets": [
            "227636ea1d9b40b5ac28cef9cb28b4c8",
            "a5bdbfc40d4e445c8c6e7e1a31a79c39",
            "edbd3cac20ce48b28004e106d6364d0b",
            "16f9c946d61943b7b4d4359c89ce31c5",
            "23644534ba354f3389fb53e3a9d07767",
            "6c8d292d051045848bec0462cd0b85ca",
            "36d0bbf33e4e4ca4a97966f4748d348c",
            "0afa8cbe1f9641a6a0f5ce521625ac22",
            "9a03f10dd7ed4f54b9b093d6e99b510b",
            "c03eb6bd0c0e440dabbebad1b37d5b37",
            "f55eae0481714b80aa7a30145486d4c8",
            "1eae9b7e6ab54bf3b9d2c546c9fa560c",
            "ef4f1ad6f07f4a3b9100d05c86685146",
            "73ff00b44b264bce83e0a534a7697ee9",
            "ef600a6844fe4557b27c857520733cdd",
            "224a993375ca44e2bb9ee1c4c559a28f",
            "08505eb8a7c649ff811e0f10ea2a27b1",
            "0cdc2e5f4c374239b3168b1261c60272",
            "a3d61fc47d1a4fd6b805c9726eae3f3b",
            "d020b1411bc1493eac1eed153dfc7c9c",
            "f507a7a4ffde4238847374b71070b33d",
            "dcf172324f264f2eac52b4fe87f46c01",
            "ea29d58feeea48cabee5cc4beed63da8",
            "d388ffbc7d984f5f87f7e468ea235f40",
            "6fcb4fa7272345bf8eb98b3732ffaed7",
            "b914d448c9e74a259e28f036b10b1103",
            "90329e7c01064e5dacfba73a54c069c9",
            "32a4dc10fb9847e9b9d1e3a23eb586f8",
            "1cce33d331af49d6bd783e12601ec28c",
            "97630224661c4b0db7bb9cf521686f24",
            "c48a2e3e9419495f8319fb560fafd30b",
            "8df2f18abed046d48d57a14421f4734f",
            "fb6cc99c79dd429fa1fcf560a5652433",
            "5b49957191464f598e839eefbe5a7f9a",
            "f5c3a667c46148cc90146515c67d8521",
            "0f1953935cb047c2a124b2c2942ae9c8",
            "70bc827e4f7c454d9281a46ad6968bd8",
            "9d33f81817a748e8b51038a582ce3a54",
            "f7aa7ed00ccb45a984b6041c28f8e332",
            "40dd2802be014e7d9d0d83fec33362e3",
            "c1371bab2a1a4f999b74dd2dd4680996",
            "48ec76136e134c00a76049a275c09f9e",
            "8fbdaba38e3643e094cfe9fc8a8ad0e5",
            "b9b9b2bdd015443ca022e6f89d06b261",
            "de6ae2f354e640008c8bacedda3239b1",
            "6a5810eeb8684e8899c5c1d42f374f5f",
            "7fe74a50fdf94259a26fe3f71390958f",
            "c913972b39594a7da6d4b7adfc0251f1"
          ]
        }
      },
      "source": [
        "# Downloading and loading a dataset\n",
        "\n",
        "dataset = nlp.load_dataset('squad', split='validation[:10%]')"
      ],
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "INFO:filelock:Lock 140132949071464 acquired on /root/.cache/huggingface/datasets/09ec6948d9db29db9a2dcd08df97ac45bccfa6aa104ea62d73c97fa4aaa5cd6c.f373b0de1570ca81b50bb03bd371604f7979e35de2cfcf2a3b4521d0b3104d9b.py.lock\n",
            "INFO:nlp.utils.file_utils:https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/squad/squad.py not found in cache or force_download set to True, downloading to /root/.cache/huggingface/datasets/tmpztwikuej\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "227636ea1d9b40b5ac28cef9cb28b4c8",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=4997.0, style=ProgressStyle(description…"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "INFO:nlp.utils.file_utils:storing https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/squad/squad.py in cache at /root/.cache/huggingface/datasets/09ec6948d9db29db9a2dcd08df97ac45bccfa6aa104ea62d73c97fa4aaa5cd6c.f373b0de1570ca81b50bb03bd371604f7979e35de2cfcf2a3b4521d0b3104d9b.py\n",
            "INFO:nlp.utils.file_utils:creating metadata file for /root/.cache/huggingface/datasets/09ec6948d9db29db9a2dcd08df97ac45bccfa6aa104ea62d73c97fa4aaa5cd6c.f373b0de1570ca81b50bb03bd371604f7979e35de2cfcf2a3b4521d0b3104d9b.py\n",
            "INFO:filelock:Lock 140132949071464 released on /root/.cache/huggingface/datasets/09ec6948d9db29db9a2dcd08df97ac45bccfa6aa104ea62d73c97fa4aaa5cd6c.f373b0de1570ca81b50bb03bd371604f7979e35de2cfcf2a3b4521d0b3104d9b.py.lock\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "INFO:filelock:Lock 140132944020200 acquired on /root/.cache/huggingface/datasets/9ba53336b6bc977097b39b8527b06ec6ba3f60a44230f2a0a918735fcd8ad902.893fb39fe374e4c574667dd71a3017b7e2e1d196f3a34fb00b56bac805447f7c.lock\n",
            "INFO:nlp.utils.file_utils:https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/squad/dataset_infos.json not found in cache or force_download set to True, downloading to /root/.cache/huggingface/datasets/tmprzcikkr9\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "9a03f10dd7ed4f54b9b093d6e99b510b",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=2240.0, style=ProgressStyle(description…"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "INFO:nlp.utils.file_utils:storing https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/squad/dataset_infos.json in cache at /root/.cache/huggingface/datasets/9ba53336b6bc977097b39b8527b06ec6ba3f60a44230f2a0a918735fcd8ad902.893fb39fe374e4c574667dd71a3017b7e2e1d196f3a34fb00b56bac805447f7c\n",
            "INFO:nlp.utils.file_utils:creating metadata file for /root/.cache/huggingface/datasets/9ba53336b6bc977097b39b8527b06ec6ba3f60a44230f2a0a918735fcd8ad902.893fb39fe374e4c574667dd71a3017b7e2e1d196f3a34fb00b56bac805447f7c\n",
            "INFO:filelock:Lock 140132944020200 released on /root/.cache/huggingface/datasets/9ba53336b6bc977097b39b8527b06ec6ba3f60a44230f2a0a918735fcd8ad902.893fb39fe374e4c574667dd71a3017b7e2e1d196f3a34fb00b56bac805447f7c.lock\n",
            "INFO:nlp.load:Checking /root/.cache/huggingface/datasets/09ec6948d9db29db9a2dcd08df97ac45bccfa6aa104ea62d73c97fa4aaa5cd6c.f373b0de1570ca81b50bb03bd371604f7979e35de2cfcf2a3b4521d0b3104d9b.py for additional imports.\n",
            "INFO:filelock:Lock 140132949070624 acquired on /root/.cache/huggingface/datasets/09ec6948d9db29db9a2dcd08df97ac45bccfa6aa104ea62d73c97fa4aaa5cd6c.f373b0de1570ca81b50bb03bd371604f7979e35de2cfcf2a3b4521d0b3104d9b.py.lock\n",
            "INFO:nlp.load:Creating main folder for dataset https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/squad/squad.py at /usr/local/lib/python3.6/dist-packages/nlp/datasets/squad\n",
            "INFO:nlp.load:Creating specific version folder for dataset https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/squad/squad.py at /usr/local/lib/python3.6/dist-packages/nlp/datasets/squad/c0327553d80335e3a3283527f64d9778df7ad04ab28f38148d072782712bb670\n",
            "INFO:nlp.load:Copying script file from https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/squad/squad.py to /usr/local/lib/python3.6/dist-packages/nlp/datasets/squad/c0327553d80335e3a3283527f64d9778df7ad04ab28f38148d072782712bb670/squad.py\n",
            "INFO:nlp.load:Copying dataset infos file from https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/squad/dataset_infos.json to /usr/local/lib/python3.6/dist-packages/nlp/datasets/squad/c0327553d80335e3a3283527f64d9778df7ad04ab28f38148d072782712bb670/dataset_infos.json\n",
            "INFO:nlp.load:Creating metadata file for dataset https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/squad/squad.py at /usr/local/lib/python3.6/dist-packages/nlp/datasets/squad/c0327553d80335e3a3283527f64d9778df7ad04ab28f38148d072782712bb670/squad.json\n",
            "INFO:filelock:Lock 140132949070624 released on /root/.cache/huggingface/datasets/09ec6948d9db29db9a2dcd08df97ac45bccfa6aa104ea62d73c97fa4aaa5cd6c.f373b0de1570ca81b50bb03bd371604f7979e35de2cfcf2a3b4521d0b3104d9b.py.lock\n",
            "INFO:nlp.builder:No config specified, defaulting to first: squad/plain_text\n",
            "INFO:nlp.info:Loading Dataset Infos from /usr/local/lib/python3.6/dist-packages/nlp/datasets/squad/c0327553d80335e3a3283527f64d9778df7ad04ab28f38148d072782712bb670\n",
            "INFO:nlp.builder:Generating dataset squad (/root/.cache/huggingface/datasets/squad/plain_text/1.0.0)\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "\n",
            "Downloading and preparing dataset squad/plain_text (download: 33.51 MiB, generated: 85.75 MiB, total: 119.27 MiB) to /root/.cache/huggingface/datasets/squad/plain_text/1.0.0...\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "INFO:filelock:Lock 140135689794560 acquired on /root/.cache/huggingface/datasets/downloads/b8bb19735e1bb591510a01cc032f4c9f969bc0eeb081ae1b328cd306f3b24008.0964abf04a1620f820955438222e0dabf800b04589f4fad232502eafb17902e0.lock\n",
            "INFO:nlp.utils.file_utils:https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json not found in cache or force_download set to True, downloading to /root/.cache/huggingface/datasets/downloads/tmpgbr8iigl\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "08505eb8a7c649ff811e0f10ea2a27b1",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=8116577.0, style=ProgressStyle(descript…"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "INFO:nlp.utils.file_utils:storing https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json in cache at /root/.cache/huggingface/datasets/downloads/b8bb19735e1bb591510a01cc032f4c9f969bc0eeb081ae1b328cd306f3b24008.0964abf04a1620f820955438222e0dabf800b04589f4fad232502eafb17902e0\n",
            "INFO:nlp.utils.file_utils:creating metadata file for /root/.cache/huggingface/datasets/downloads/b8bb19735e1bb591510a01cc032f4c9f969bc0eeb081ae1b328cd306f3b24008.0964abf04a1620f820955438222e0dabf800b04589f4fad232502eafb17902e0\n",
            "INFO:filelock:Lock 140135689794560 released on /root/.cache/huggingface/datasets/downloads/b8bb19735e1bb591510a01cc032f4c9f969bc0eeb081ae1b328cd306f3b24008.0964abf04a1620f820955438222e0dabf800b04589f4fad232502eafb17902e0.lock\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "INFO:filelock:Lock 140132943828360 acquired on /root/.cache/huggingface/datasets/downloads/9d5462987ef5f814fe15a369c1724f6ec39a2018b3b6271a9d7d2598686ca2ff.d80c10825b2cc6bbce928099e1426b391827370bac181b297931a597daf19886.lock\n",
            "INFO:nlp.utils.file_utils:https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json not found in cache or force_download set to True, downloading to /root/.cache/huggingface/datasets/downloads/tmpb2wrz3hx\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "6fcb4fa7272345bf8eb98b3732ffaed7",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1054280.0, style=ProgressStyle(descript…"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "INFO:nlp.utils.file_utils:storing https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json in cache at /root/.cache/huggingface/datasets/downloads/9d5462987ef5f814fe15a369c1724f6ec39a2018b3b6271a9d7d2598686ca2ff.d80c10825b2cc6bbce928099e1426b391827370bac181b297931a597daf19886\n",
            "INFO:nlp.utils.file_utils:creating metadata file for /root/.cache/huggingface/datasets/downloads/9d5462987ef5f814fe15a369c1724f6ec39a2018b3b6271a9d7d2598686ca2ff.d80c10825b2cc6bbce928099e1426b391827370bac181b297931a597daf19886\n",
            "INFO:filelock:Lock 140132943828360 released on /root/.cache/huggingface/datasets/downloads/9d5462987ef5f814fe15a369c1724f6ec39a2018b3b6271a9d7d2598686ca2ff.d80c10825b2cc6bbce928099e1426b391827370bac181b297931a597daf19886.lock\n",
            "INFO:nlp.utils.info_utils:All the checksums matched successfully.\n",
            "INFO:nlp.builder:Generating split train\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "fb6cc99c79dd429fa1fcf560a5652433",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "INFO:root:generating examples from = /root/.cache/huggingface/datasets/downloads/b8bb19735e1bb591510a01cc032f4c9f969bc0eeb081ae1b328cd306f3b24008.0964abf04a1620f820955438222e0dabf800b04589f4fad232502eafb17902e0\n",
            "INFO:nlp.arrow_writer:Done writing 87599 examples in 79317110 bytes /root/.cache/huggingface/datasets/squad/plain_text/1.0.0.incomplete/squad-train.arrow.\n",
            "INFO:nlp.builder:Generating split validation\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "\r"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "c1371bab2a1a4f999b74dd2dd4680996",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "INFO:root:generating examples from = /root/.cache/huggingface/datasets/downloads/9d5462987ef5f814fe15a369c1724f6ec39a2018b3b6271a9d7d2598686ca2ff.d80c10825b2cc6bbce928099e1426b391827370bac181b297931a597daf19886\n",
            "INFO:nlp.arrow_writer:Done writing 10570 examples in 10472653 bytes /root/.cache/huggingface/datasets/squad/plain_text/1.0.0.incomplete/squad-validation.arrow.\n",
            "INFO:nlp.utils.info_utils:All the splits matched successfully.\n",
            "INFO:nlp.builder:Constructing Dataset for split validation[:10%], from /root/.cache/huggingface/datasets/squad/plain_text/1.0.0\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "\rDataset squad downloaded and prepared to /root/.cache/huggingface/datasets/squad/plain_text/1.0.0. Subsequent calls will reuse this data.\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rQ0G-eK3vSUw",
        "colab_type": "text"
      },
      "source": [
        "This call to `nlp.load_dataset()` does the following steps under the hood:\n",
        "\n",
        "1. Download and import in the library the **SQuAD python processing script** from HuggingFace AWS bucket if it's not already stored in the library. You can find the SQuAD processing script [here](https://github.com/huggingface/nlp/tree/master/datasets/squad/squad.py) for instance.\n",
        "\n",
        "   Processing scripts are small python scripts which define the info (citation, description) and format of the dataset and contain the URL to the original SQuAD JSON files and the code to load examples from the original SQuAD JSON files.\n",
        "\n",
        "\n",
        "2. Run the SQuAD python processing script which will:\n",
        "    - **Download the SQuAD dataset** from the original URL (see the script) if it's not already downloaded and cached.\n",
        "    - **Process and cache** all SQuAD in a structured Arrow table for each standard splits stored on the drive.\n",
        "\n",
        "      Arrow table are arbitrarly long tables, typed with types that can be mapped to numpy/pandas/python standard types and can store nested objects. They can be directly access from drive, loaded in RAM or even streamed over the web.\n",
        "    \n",
        "\n",
        "3. Return a **dataset build from the splits** asked by the user (default: all), in the above example we create a dataset with the first 10% of the validation split."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "fercoFwLvSUx",
        "colab_type": "code",
        "outputId": "a9fe5017-fbb6-42de-ac32-67ee2dbc93ab",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 479
        }
      },
      "source": [
        "# Informations on the dataset (description, citation, size, splits, format...)\n",
        "# are provided in `dataset.info` (as a python dataclass)\n",
        "for key, value in asdict(dataset.info).items():\n",
        "    print('👉 ' + key + ': ' + str(value))"
      ],
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "👉 description: Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset, consisting of questions posed by crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span, from the corresponding reading passage, or the question might be unanswerable.\n",
            "\n",
            "👉 citation: @article{2016arXiv160605250R,\n",
            "       author = {{Rajpurkar}, Pranav and {Zhang}, Jian and {Lopyrev},\n",
            "                 Konstantin and {Liang}, Percy},\n",
            "        title = \"{SQuAD: 100,000+ Questions for Machine Comprehension of Text}\",\n",
            "      journal = {arXiv e-prints},\n",
            "         year = 2016,\n",
            "          eid = {arXiv:1606.05250},\n",
            "        pages = {arXiv:1606.05250},\n",
            "archivePrefix = {arXiv},\n",
            "       eprint = {1606.05250},\n",
            "}\n",
            "\n",
            "👉 homepage: https://rajpurkar.github.io/SQuAD-explorer/\n",
            "👉 license: \n",
            "👉 features: {'id': {'dtype': 'string', 'id': None, '_type': 'Value'}, 'title': {'dtype': 'string', 'id': None, '_type': 'Value'}, 'context': {'dtype': 'string', 'id': None, '_type': 'Value'}, 'question': {'dtype': 'string', 'id': None, '_type': 'Value'}, 'answers': {'feature': {'text': {'dtype': 'string', 'id': None, '_type': 'Value'}, 'answer_start': {'dtype': 'int32', 'id': None, '_type': 'Value'}}, 'length': -1, 'id': None, '_type': 'Sequence'}}\n",
            "👉 supervised_keys: None\n",
            "👉 builder_name: squad\n",
            "👉 config_name: plain_text\n",
            "👉 version: {'version_str': '1.0.0', 'description': 'New split API (https://tensorflow.org/datasets/splits)', 'nlp_version_to_prepare': None, 'major': 1, 'minor': 0, 'patch': 0}\n",
            "👉 splits: {'train': {'name': 'train', 'num_bytes': 79317110, 'num_examples': 87599, 'dataset_name': 'squad'}, 'validation': {'name': 'validation', 'num_bytes': 10472653, 'num_examples': 10570, 'dataset_name': 'squad'}}\n",
            "👉 download_checksums: {'https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json': {'num_bytes': 30288272, 'checksum': '3527663986b8295af4f7fcdff1ba1ff3f72d07d61a20f487cb238a6ef92fd955'}, 'https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json': {'num_bytes': 4854279, 'checksum': '95aa6a52d5d6a735563366753ca50492a658031da74f301ac5238b03966972c9'}}\n",
            "👉 download_size: 35142551\n",
            "👉 dataset_size: 89789763\n",
            "👉 size_in_bytes: 124932314\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GE0E87zsvSUz",
        "colab_type": "text"
      },
      "source": [
        "## Inspecting and using the dataset: elements, slices and columns"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DKf4YFnevSU0",
        "colab_type": "text"
      },
      "source": [
        "The returned `Dataset` object is a memory mapped dataset that behave similarly to a normal map-style dataset. It is backed by an Apache Arrow table which allows many interesting features."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "tP1xPqSyvSU0",
        "colab_type": "code",
        "outputId": "13742347-8ff9-41c0-a387-1a5b8af03fd4",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "print(dataset)"
      ],
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Dataset(schema: {'id': 'string', 'title': 'string', 'context': 'string', 'question': 'string', 'answers': 'struct<text: list<item: string>, answer_start: list<item: int32>>'}, num_rows: 1057)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aiO3rC8yvSU2",
        "colab_type": "text"
      },
      "source": [
        "You can query it's length and get items or slices like you would do normally with a python mapping."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xxLcdj2yvSU3",
        "colab_type": "code",
        "outputId": "cdaa61a8-d09a-4fc9-b741-f59006eb7da4",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 374
        }
      },
      "source": [
        "from pprint import pprint\n",
        "\n",
        "print(f\"👉Dataset len(dataset): {len(dataset)}\")\n",
        "print(\"\\n👉First item 'dataset[0]':\")\n",
        "pprint(dataset[0])"
      ],
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "👉Dataset len(dataset): 1057\n",
            "\n",
            "👉First item 'dataset[0]':\n",
            "{'answers': {'answer_start': [177, 177, 177],\n",
            "             'text': ['Denver Broncos', 'Denver Broncos', 'Denver Broncos']},\n",
            " 'context': 'Super Bowl 50 was an American football game to determine the '\n",
            "            'champion of the National Football League (NFL) for the 2015 '\n",
            "            'season. The American Football Conference (AFC) champion Denver '\n",
            "            'Broncos defeated the National Football Conference (NFC) champion '\n",
            "            'Carolina Panthers 24–10 to earn their third Super Bowl title. The '\n",
            "            \"game was played on February 7, 2016, at Levi's Stadium in the San \"\n",
            "            'Francisco Bay Area at Santa Clara, California. As this was the '\n",
            "            '50th Super Bowl, the league emphasized the \"golden anniversary\" '\n",
            "            'with various gold-themed initiatives, as well as temporarily '\n",
            "            'suspending the tradition of naming each Super Bowl game with '\n",
            "            'Roman numerals (under which the game would have been known as '\n",
            "            '\"Super Bowl L\"), so that the logo could prominently feature the '\n",
            "            'Arabic numerals 50.',\n",
            " 'id': '56be4db0acb8001400a502ec',\n",
            " 'question': 'Which NFL team represented the AFC at Super Bowl 50?',\n",
            " 'title': 'Super_Bowl_50'}\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zk1WQ_cczP5w",
        "colab_type": "code",
        "outputId": "74bf043b-0037-4ba6-dbc1-10b991cea6c7",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 748
        }
      },
      "source": [
        "# Or get slices with several examples:\n",
        "print(\"\\n👉Slice of the two items 'dataset[10:12]':\")\n",
        "pprint(dataset[10:12])"
      ],
      "execution_count": 10,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "\n",
            "👉Slice of the two items 'dataset[10:12]':\n",
            "OrderedDict([('id', ['56bea9923aeaaa14008c91bb', '56beace93aeaaa14008c91df']),\n",
            "             ('title', ['Super_Bowl_50', 'Super_Bowl_50']),\n",
            "             ('context',\n",
            "              ['Super Bowl 50 was an American football game to determine the '\n",
            "               'champion of the National Football League (NFL) for the 2015 '\n",
            "               'season. The American Football Conference (AFC) champion Denver '\n",
            "               'Broncos defeated the National Football Conference (NFC) '\n",
            "               'champion Carolina Panthers 24–10 to earn their third Super '\n",
            "               \"Bowl title. The game was played on February 7, 2016, at Levi's \"\n",
            "               'Stadium in the San Francisco Bay Area at Santa Clara, '\n",
            "               'California. As this was the 50th Super Bowl, the league '\n",
            "               'emphasized the \"golden anniversary\" with various gold-themed '\n",
            "               'initiatives, as well as temporarily suspending the tradition '\n",
            "               'of naming each Super Bowl game with Roman numerals (under '\n",
            "               'which the game would have been known as \"Super Bowl L\"), so '\n",
            "               'that the logo could prominently feature the Arabic numerals '\n",
            "               '50.',\n",
            "               'Super Bowl 50 was an American football game to determine the '\n",
            "               'champion of the National Football League (NFL) for the 2015 '\n",
            "               'season. The American Football Conference (AFC) champion Denver '\n",
            "               'Broncos defeated the National Football Conference (NFC) '\n",
            "               'champion Carolina Panthers 24–10 to earn their third Super '\n",
            "               \"Bowl title. The game was played on February 7, 2016, at Levi's \"\n",
            "               'Stadium in the San Francisco Bay Area at Santa Clara, '\n",
            "               'California. As this was the 50th Super Bowl, the league '\n",
            "               'emphasized the \"golden anniversary\" with various gold-themed '\n",
            "               'initiatives, as well as temporarily suspending the tradition '\n",
            "               'of naming each Super Bowl game with Roman numerals (under '\n",
            "               'which the game would have been known as \"Super Bowl L\"), so '\n",
            "               'that the logo could prominently feature the Arabic numerals '\n",
            "               '50.']),\n",
            "             ('question',\n",
            "              ['What day was the Super Bowl played on?',\n",
            "               'Who won Super Bowl 50?']),\n",
            "             ('answers',\n",
            "              [{'answer_start': [334, 334, 334],\n",
            "                'text': ['February 7, 2016', 'February 7', 'February 7, 2016']},\n",
            "               {'answer_start': [177, 177, 177],\n",
            "                'text': ['Denver Broncos',\n",
            "                         'Denver Broncos',\n",
            "                         'Denver Broncos']}])])\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "QXj2Qr5KvSU5",
        "colab_type": "code",
        "outputId": "fb1eb416-a73f-4fb7-d9dc-f6b6a11ffacd",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 54
        }
      },
      "source": [
        "# You can get a full column of the dataset by indexing with its name as a string:\n",
        "print(dataset['question'][:10])"
      ],
      "execution_count": 11,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "['Which NFL team represented the AFC at Super Bowl 50?', 'Which NFL team represented the NFC at Super Bowl 50?', 'Where did Super Bowl 50 take place?', 'Which NFL team won Super Bowl 50?', 'What color was used to emphasize the 50th anniversary of the Super Bowl?', 'What was the theme of Super Bowl 50?', 'What day was the game played on?', 'What is the AFC short for?', 'What was the theme of Super Bowl 50?', 'What does AFC stand for?']\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6Au7rqPMvSU7",
        "colab_type": "text"
      },
      "source": [
        "The `__getitem__` method will return different format depending on the type of query:\n",
        "\n",
        "- Items like `dataset[0]` are returned as dict of elements.\n",
        "- Slices like `dataset[10:20]` are returned as dict of lists of elements.\n",
        "- Columns like `dataset['question']` are returned as a list of elements.\n",
        "\n",
        "This may seems surprising at first but in our experiments it's actually a lot easier to use for data processing than returning the same format for each of these views on the dataset."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6DB_y79cvSU8",
        "colab_type": "text"
      },
      "source": [
        "In particular, you can easily iterate along columns in slices, and also naturally permute consecutive indexings with identical results as showed here by permuting column indexing with elements and slices:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "wjGocqArvSU9",
        "colab_type": "code",
        "outputId": "e0dfb174-8860-4b71-bc4b-b56bbc2f31b1",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 51
        }
      },
      "source": [
        "print(dataset[0]['question'] == dataset['question'][0])\n",
        "print(dataset[10:20]['context'] == dataset['context'][10:20])"
      ],
      "execution_count": 12,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "True\n",
            "True\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b1-Kj1xQvSU_",
        "colab_type": "text"
      },
      "source": [
        "### Dataset are internally typed and structured\n",
        "\n",
        "The dataset is backed by one (or several) Apache Arrow tables which are typed and allows for fast retrieval and access as well as arbitrary-size memory mapping.\n",
        "\n",
        "This means respectively that the format for the dataset is clearly defined and that you can load datasets of arbitrary size without worrying about RAM memory limitation (basically the dataset take no space in RAM, it's directly read from drive when needed with fast IO access)."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "rAnp_RyPvSVA",
        "colab_type": "code",
        "outputId": "e154e878-4e33-4006-a762-cdc9c9cf50e4",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 0
        }
      },
      "source": [
        "# You can inspect the dataset column names and type \n",
        "print(dataset.column_names)\n",
        "print(dataset.schema)"
      ],
      "execution_count": 13,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "['id', 'title', 'context', 'question', 'answers']\n",
            "id: string not null\n",
            "title: string not null\n",
            "context: string not null\n",
            "question: string not null\n",
            "answers: struct<text: list<item: string>, answer_start: list<item: int32>> not null\n",
            "  child 0, text: list<item: string>\n",
            "      child 0, item: string\n",
            "  child 1, answer_start: list<item: int32>\n",
            "      child 0, item: int32\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "au4v3mOQvSVC",
        "colab_type": "text"
      },
      "source": [
        "### Additional misc properties"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "efFhDWhlvSVC",
        "colab_type": "code",
        "outputId": "7b593e9d-eeb1-4b16-ea6b-46f4a5121e95",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 0
        }
      },
      "source": [
        "# Datasets also have a bunch of properties you can access\n",
        "print(\"The number of bytes allocated on the drive is \", dataset.nbytes)\n",
        "print(\"For comparison, here is the number of bytes allocated in memory which can be\")\n",
        "print(\"accessed with `nlp.total_allocated_bytes()`: \", nlp.total_allocated_bytes())\n",
        "print(\"The number of rows\", dataset.num_rows)\n",
        "print(\"The number of columns\", dataset.num_columns)\n",
        "print(\"The shape (rows, columns)\", dataset.shape)"
      ],
      "execution_count": 14,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "The number of bytes allocated on the drive is  9855914\n",
            "For comparison, here is the number of bytes allocated in memory which can be\n",
            "accessed with `nlp.total_allocated_bytes()`:  0\n",
            "The number of rows 1057\n",
            "The number of columns 5\n",
            "The shape (rows, columns) (1057, 5)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "o2_FBqAQvSVE",
        "colab_type": "text"
      },
      "source": [
        "### Additional misc methods"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "SznY_XqGvSVF",
        "colab_type": "code",
        "outputId": "674b8e3e-6903-4cf3-b0a2-fe7a2cdc8345",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 0
        }
      },
      "source": [
        "# We can list the unique elements in a column. This is done by the backend (so fast!)\n",
        "print(f\"dataset.unique('title'): {dataset.unique('title')}\")\n",
        "\n",
        "# This will drop the column 'id'\n",
        "dataset.drop('id')  # Remove column 'id'\n",
        "print(f\"After dataset.drop('id'), remaining columns are {dataset.column_names}\")\n",
        "\n",
        "# This will flatten nested columns (in 'answers' in our case)\n",
        "dataset.flatten()\n",
        "print(f\"After dataset.flatten(), column names are {dataset.column_names}\")\n",
        "\n",
        "# We can also \"dictionnary encode\" a column if many of it's elements are similar\n",
        "# This will reduce it's size by only storing the distinct elements (e.g. string)\n",
        "# It only has effect on the internal storage (no difference from a user point of view)\n",
        "dataset.dictionary_encode_column('title')"
      ],
      "execution_count": 15,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "dataset.unique('title'): ['Super_Bowl_50', 'Warsaw']\n",
            "After dataset.drop('id'), remaining columns are ['title', 'context', 'question', 'answers']\n",
            "After dataset.flatten(), column names are ['title', 'context', 'question', 'answers.text', 'answers.answer_start']\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QdyuKs4VvSVH",
        "colab_type": "text"
      },
      "source": [
        "## Cache\n",
        "\n",
        "`nlp` datasets are backed by Apache Arrow cache files which allows:\n",
        "- to load arbitrary large datasets by using [memory mapping](https://en.wikipedia.org/wiki/Memory-mapped_file) (as long as the datasets can fit on the drive)\n",
        "- to use a fast backend to process the dataset efficiently\n",
        "- to do smart caching by storing and reusing the results of operations performed on the drive\n",
        "\n",
        "Let's dive a bit in these parts now"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9fUcKwcbvSVH",
        "colab_type": "text"
      },
      "source": [
        "You can check the current cache files backing the dataset with the `.cache_file` property"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zu8TgHTYvSVI",
        "colab_type": "code",
        "outputId": "1bbb8030-58aa-4710-9be5-04faffa19467",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 68
        }
      },
      "source": [
        "dataset.cache_files"
      ],
      "execution_count": 16,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "({'filename': '/root/.cache/huggingface/datasets/squad/plain_text/1.0.0/squad-validation.arrow',\n",
              "  'skip': 0,\n",
              "  'take': 1057},)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 16
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LjeICK5GvSVK",
        "colab_type": "text"
      },
      "source": [
        "You can clean up the cache files in the current dataset directory (only keeping the currently used one) with `.cleanup_cache_files()`.\n",
        "\n",
        "Be careful that no other process is using some other cache files when running this command."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "3_WNU3dwvSVL",
        "colab_type": "code",
        "outputId": "997bf2be-0589-4a45-de74-ace49ca66759",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 51
        }
      },
      "source": [
        "dataset.cleanup_cache_files()  # Returns the number of removed cache files"
      ],
      "execution_count": 17,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "INFO:nlp.arrow_dataset:Listing files in /root/.cache/huggingface/datasets/squad/plain_text/1.0.0\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "0"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 17
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1Ox7ppKDvSVN",
        "colab_type": "text"
      },
      "source": [
        "## Modifying the dataset with `dataset.map`\n",
        "\n",
        "There is a powerful method `.map()` which is inspired by `tf.data` map method and that you can use to apply a function to each examples, independantly or in batch."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Yz2-27HevSVN",
        "colab_type": "code",
        "outputId": "8160a9e5-a5b1-4b38-d61a-701512f17b3f",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 71
        }
      },
      "source": [
        "# `.map()` takes a callable accepting a dict as argument\n",
        "# (same dict as returned by dataset[i])\n",
        "# and iterate over the dataset by calling the function with each example.\n",
        "\n",
        "# Let's print the length of each `context` string in our subset of the dataset\n",
        "# (10% of the validation i.e. 1057 examples)\n",
        "\n",
        "dataset.map(lambda example: print(len(example['context']), end=','))"
      ],
      "execution_count": 18,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "\r0it [00:00, ?it/s]"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "775,775,775,775,775,775,775,775,775,775,775,775,775,775,775,775,775,775,775,775,775,775,775,775,775,775,775,775,775,775,775,775,637,637,637,637,637,637,637,637,637,637,637,637,637,637,637,637,637,637,637,637,637,637,637,637,637,347,347,347,347,347,347,347,347,347,347,347,347,347,347,347,347,347,347,347,347,347,347,347,347,347,394,394,394,394,394,394,394,394,394,394,394,394,394,394,394,394,394,394,394,394,394,394,394,394,394,394,179,179,179,179,179,179,179,179,179,179,179,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,168,638,638,638,638,638,638,638,638,638,638,638,638,638,638,638,638,638,638,638,638,638,326,326,326,326,326,326,326,326,326,326,326,326,326,326,326,326,326,326,326,326,326,326,326,704,704,704,704,704,704,704,704,704,704,704,704,704,704,704,704,704,704,917,917,917,917,917,917,917,917,917,917,917,917,917,917,917,917,917,917,917,917,1271,1271,1271,1271,1271,1271,1271,1271,1271,1271,1271,1271,1271,1271,1271,1271,1271,1271,1271,1271,1166,1166,1166,1166,1166,1166,1166,1166,1166,1166,1166,1166,1166,1166,2060,2060,2060,2060,2060,2060,2060,2060,2060,2060,2060,2060,2060,2060,2060,2060,2060,2060,2060,2060,929,929,929,929,929,929,929,929,929,929,929,929,929,929,929,929,929,929,929,704,704,704,704,704,704,704,704,704,704,704,704,704,704,353,353,353,353,353,353,353,353,353,353,353,353,353,353,353,464,464,464,464,464,464,464,464,464,464,464,464,464,464,464,464,306,306,306,306,306,306,306,306,306,306,306,306,372,372,372,372,372,372,372,372,372,372,372,372,372,372,372,372,372,496,496,496,496,"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "\r469it [00:00, 4689.78it/s]"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "496,496,496,496,496,496,496,496,496,496,496,260,260,260,260,260,260,260,260,260,874,874,874,874,874,874,874,874,874,874,874,874,874,874,1025,1025,1025,1025,1025,1025,1025,1025,1025,1025,1025,1025,1025,1025,1025,176,176,176,176,176,176,176,176,176,176,176,176,176,176,176,176,782,782,782,782,782,782,782,782,782,782,782,782,782,782,782,782,536,536,536,536,536,536,536,536,536,666,666,666,666,666,666,666,666,666,666,666,666,666,666,666,666,666,495,495,495,495,495,495,495,495,495,495,495,385,385,385,385,385,385,385,385,385,385,385,385,385,385,385,385,385,385,385,441,441,441,441,441,441,441,441,441,441,441,357,357,357,357,357,357,357,357,357,296,296,296,296,296,296,296,296,296,296,644,644,644,644,644,644,644,644,644,644,644,644,644,644,644,644,644,804,804,804,804,804,804,804,804,804,804,804,397,397,397,397,397,397,397,397,397,397,397,397,397,397,360,360,360,360,360,360,360,973,973,973,973,973,973,973,973,973,973,973,973,973,973,263,263,263,263,263,263,263,263,263,263,263,568,568,568,568,568,568,568,568,568,568,568,264,264,264,264,264,264,264,264,264,264,264,264,264,264,264,892,892,892,892,892,892,892,892,892,892,892,206,206,206,206,206,489,489,489,489,489,489,489,489,489,489,489,489,489,181,181,181,181,181,181,181,181,181,181,181,181,531,531,531,531,531,531,531,531,531,531,531,531,664,664,664,664,664,664,664,664,664,664,664,664,664,664,672,672,672,672,672,672,672,672,672,672,672,672,672,672,858,858,858,858,858,858,858,858,858,858,858,858,634,634,634,634,634,634,634,634,634,634,634,634,634,634,891,891,891,891,891,891,891,891,891,891,891,891,891,488,488,488,488,488,488,488,488,488,488,488,488,942,942,942,942,942,942,942,942,942,942,942,942,942,942,942,1162,1162,1162,1162,1162,1162,1162,1162,1162,1162,1162,1162,1162,1162,1162,1353,1353,1353,1353,1353,1353,1353,1353,1353,1353,1353,1353,1353,1353,522,522,522,522,522,1643,1643,1643,1643,1643,628,628,628,628,628,758,758,758,758,758,883,883,883,883,883,559,559,559,559,559,603,603,603,603,631,631,631,631,631,626,626,62
6,626,626,541,541,541,541,541,795,795,795,795,795,591,591,591,591,591,568,568,568,568,568,536,536,536,536,536,575,575,575,575,575,571,571,571,571,571,641,641,641,641,641,665,665,665,665,665,1088,1088,1088,1088,1088,1619,1619,1619,1619,1619,939,939,939,939,939,"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "\r1057it [00:00, 5765.00it/s]"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "865,865,865,865,865,711,711,711,711,711,831,831,831,831,831,501,501,501,501,501,676,676,676,676,676,854,854,854,854,854,784,784,784,784,784,641,641,641,641,641,544,544,544,544,544,918,918,918,918,918,763,763,763,763,763,906,906,906,906,906,632,632,632,632,632,869,869,869,869,869,1044,1044,1044,1044,1044,760,760,760,760,760,715,715,715,715,715,838,838,838,838,838,881,881,881,881,881,940,940,940,940,940,618,618,618,618,618,1205,1205,1205,534,534,534,534,534,757,757,757,757,757,1239,1239,1239,1239,1239,609,609,609,609,609,798,798,798,798,798,613,613,613,613,613,613,613,613,613,613,"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "Dataset(schema: {'title': 'string', 'context': 'string', 'question': 'string', 'answers.text': 'list<item: string>', 'answers.answer_start': 'list<item: int32>'}, num_rows: 1057)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 18
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ta3celHnvSVP",
        "colab_type": "text"
      },
      "source": [
        "This is basically the same as doing\n",
        "\n",
        "```python\n",
        "for example in dataset:\n",
        "    function(example)\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "i_Ouw5gDvSVP",
        "colab_type": "text"
      },
      "source": [
        "The above example had no effect on the dataset because the method we supplied to `.map()` didn't return a `dict` or a `abc.Mapping` that could be used to update the examples in the dataset.\n",
        "\n",
        "In such a case, `.map()` will return the same dataset (`self`).\n",
        "\n",
        "Now let's see how we can use a method that actually modify the dataset."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cEnCi9DFvSVQ",
        "colab_type": "text"
      },
      "source": [
        "### Modifying the dataset example by example"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kA37VgZhvSVQ",
        "colab_type": "text"
      },
      "source": [
        "The main interest of `.map()` is to update and modify the content of the table and leverage smart caching and fast backend.\n",
        "\n",
        "To use `.map()` to update elements in the table you need to provide a function with the following signature: `function(example: dict) -> dict`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "vUr65K-4vSVQ",
        "colab_type": "code",
        "outputId": "549a876e-bed8-4396-e96c-451876e98a32",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 85
        }
      },
      "source": [
        "# Let's add a prefix 'My cute title: ' to each of our titles\n",
        "\n",
        "def add_prefix_to_title(example):\n",
        "    example['title'] = 'My cute title: ' + example['title']\n",
        "    return example\n",
        "\n",
        "dataset = dataset.map(add_prefix_to_title)\n",
        "\n",
        "print(dataset.unique('title'))"
      ],
      "execution_count": 19,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "INFO:nlp.arrow_dataset:Caching processed dataset at /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/cache-d99267fa77c05190a9cefa70b7a33564.arrow\n",
            "1057it [00:00, 22608.90it/s]\n",
            "INFO:nlp.arrow_writer:Done writing 1057 examples in 905032 bytes /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/cache-d99267fa77c05190a9cefa70b7a33564.arrow.\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "['My cute title: Super_Bowl_50', 'My cute title: Warsaw']\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FcZ_amDAvSVS",
        "colab_type": "text"
      },
      "source": [
        "This call to `.map()` compute and return the updated table. It will also store the updated table in a cache file indexed by the current state and the mapped function.\n",
        "\n",
        "A subsequent call to `.map()` (even in another python session) will reuse the cached file instead of recomputing the operation.\n",
        "\n",
        "You can test this by running again the previous cell, you will see that the result are directly loaded from the cache and not re-computed again.\n",
        "\n",
        "The updated dataset returned by `.map()` is (again) directly memory mapped from drive and not allocated in RAM."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Skbf8LUEvSVT",
        "colab_type": "text"
      },
      "source": [
        "The function you provide to `.map()` should accept an input with the format of an item of the dataset: `function(dataset[0])` and return a python dict.\n",
        "\n",
        "The columns and type of the outputs can be different than the input dict. In this case the new keys will be added as additional columns in the dataset.\n",
        "\n",
        "Bascially each dataset example dict is updated with the dictionary returned by the function like this: `example.update(function(example))`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "d5De0CfTvSVT",
        "colab_type": "code",
        "outputId": "fdc2f7c1-b0cc-43b2-90f3-f72d68c2de80",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 85
        }
      },
      "source": [
        "# Since the input example dict is updated with our function output dict,\n",
        "# we can actually just return the updated 'title' field\n",
        "dataset = dataset.map(lambda example: {'title': 'My cutest title: ' + example['title']})\n",
        "\n",
        "print(dataset.unique('title'))"
      ],
      "execution_count": 20,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "INFO:nlp.arrow_dataset:Caching processed dataset at /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/cache-b5fef4a553eb27e130e299608150b28a.arrow\n",
            "1057it [00:00, 21986.17it/s]\n",
            "INFO:nlp.arrow_writer:Done writing 1057 examples in 923001 bytes /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/cache-b5fef4a553eb27e130e299608150b28a.arrow.\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "['My cutest title: My cute title: Super_Bowl_50', 'My cutest title: My cute title: Warsaw']\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Q5vny56-vSVV",
        "colab_type": "text"
      },
      "source": [
        "#### Removing columns\n",
        "You can also remove columns when running map with the `remove_columns=List[str]` argument."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-sPWnsz-vSVW",
        "colab_type": "code",
        "outputId": "15a0ef14-0995-4f0e-bb80-cc38d6087caf",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 102
        }
      },
      "source": [
        "# This will remove the 'title' column while doing the update (after having send it the the mapped function so you can use it in your function!)\n",
        "dataset = dataset.map(lambda example: {'new_title': 'Wouhahh: ' + example['title']},\n",
        "                     remove_columns=['title'])\n",
        "\n",
        "print(dataset.column_names)\n",
        "print(dataset.unique('new_title'))"
      ],
      "execution_count": 21,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "INFO:nlp.arrow_dataset:Caching processed dataset at /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/cache-ca7a1d9079c4c5bcb746ee34fb9587b6.arrow\n",
            "1057it [00:00, 17295.53it/s]\n",
            "INFO:nlp.arrow_writer:Done writing 1057 examples in 932514 bytes /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/cache-ca7a1d9079c4c5bcb746ee34fb9587b6.arrow.\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "['context', 'question', 'answers.text', 'answers.answer_start', 'new_title']\n",
            "['Wouhahh: My cutest title: My cute title: Super_Bowl_50', 'Wouhahh: My cutest title: My cute title: Warsaw']\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "G459HzD-vSVY",
        "colab_type": "text"
      },
      "source": [
        "#### Using examples indices\n",
        "With `with_indices=True`, dataset indices (from `0` to `len(dataset)`) will be supplied to the function which must thus have the following signature: `function(example: dict, indice: int) -> dict`"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_kFL37R2vSVY",
        "colab_type": "code",
        "outputId": "29b8acdd-8a27-4ee3-bf2a-bc0c3d06e4e2",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 153
        }
      },
      "source": [
        "# This will add the index in the dataset to the 'question' field\n",
        "dataset = dataset.map(lambda example, idx: {'question': f'{idx}: ' + example['question']},\n",
        "                      with_indices=True)\n",
        "\n",
        "print('\\n'.join(dataset['question'][:5]))"
      ],
      "execution_count": 22,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "INFO:nlp.arrow_dataset:Caching processed dataset at /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/cache-0be0bac8de2ddc42bcc650f62286e93a.arrow\n",
            "1057it [00:00, 20453.13it/s]\n",
            "INFO:nlp.arrow_writer:Done writing 1057 examples in 937746 bytes /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/cache-0be0bac8de2ddc42bcc650f62286e93a.arrow.\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "0: Which NFL team represented the AFC at Super Bowl 50?\n",
            "1: Which NFL team represented the NFC at Super Bowl 50?\n",
            "2: Where did Super Bowl 50 take place?\n",
            "3: Which NFL team won Super Bowl 50?\n",
            "4: What color was used to emphasize the 50th anniversary of the Super Bowl?\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xckhVEWFvSVb",
        "colab_type": "text"
      },
      "source": [
        "### Modifying the dataset with batched updates"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dzmicbSnvSVb",
        "colab_type": "text"
      },
      "source": [
        "`.map()` can also work with batch of examples (slices of the dataset).\n",
        "\n",
        "This is particularly interesting if you have a function that can handle batch of inputs like the tokenizers of HuggingFace `tokenizers`.\n",
        "\n",
        "To work on batched inputs set `batched=True` when calling `.map()` and supply a function with the following signature: `function(examples: Dict[List]) -> Dict[List]` or, if you use indices, `function(examples: Dict[List], indices: List[int]) -> Dict[List]`).\n",
        "\n",
        "Bascially, your function should accept an input with the format of a slice of the dataset: `function(dataset[:10])`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "pxHbgSTL0itj",
        "colab_type": "code",
        "outputId": "280b22d8-1f2f-43f3-a2b3-51a45944f029",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 561
        }
      },
      "source": [
        "!pip install transformers"
      ],
      "execution_count": 23,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Collecting transformers\n",
            "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/22/97/7db72a0beef1825f82188a4b923e62a146271ac2ced7928baa4d47ef2467/transformers-2.9.1-py3-none-any.whl (641kB)\n",
            "\u001b[K     |████████████████████████████████| 645kB 2.7MB/s \n",
            "\u001b[?25hRequirement already satisfied: dataclasses; python_version < \"3.7\" in /usr/local/lib/python3.6/dist-packages (from transformers) (0.7)\n",
            "Collecting sacremoses\n",
            "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)\n",
            "\u001b[K     |████████████████████████████████| 890kB 12.4MB/s \n",
            "\u001b[?25hRequirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from transformers) (2.23.0)\n",
            "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from transformers) (1.18.4)\n",
            "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.6/dist-packages (from transformers) (2019.12.20)\n",
            "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.6/dist-packages (from transformers) (4.41.1)\n",
            "Collecting sentencepiece\n",
            "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/3b/88/49e772d686088e1278766ad68a463513642a2a877487decbd691dec02955/sentencepiece-0.1.90-cp36-cp36m-manylinux1_x86_64.whl (1.1MB)\n",
            "\u001b[K     |████████████████████████████████| 1.1MB 16.5MB/s \n",
            "\u001b[?25hCollecting tokenizers==0.7.0\n",
            "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/14/e5/a26eb4716523808bb0a799fcfdceb6ebf77a18169d9591b2f46a9adb87d9/tokenizers-0.7.0-cp36-cp36m-manylinux1_x86_64.whl (3.8MB)\n",
            "\u001b[K     |████████████████████████████████| 3.8MB 15.9MB/s \n",
            "\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.6/dist-packages (from transformers) (3.0.12)\n",
            "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (1.12.0)\n",
            "Requirement already satisfied: click in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (7.1.2)\n",
            "Requirement already satisfied: joblib in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (0.14.1)\n",
            "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (1.24.3)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (2020.4.5.1)\n",
            "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (2.9)\n",
            "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (3.0.4)\n",
            "Building wheels for collected packages: sacremoses\n",
            "  Building wheel for sacremoses (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for sacremoses: filename=sacremoses-0.0.43-cp36-none-any.whl size=893260 sha256=867f894393f82d69e6dbf5dd89fc230af1f78e8c900506dd4999586309d97f2e\n",
            "  Stored in directory: /root/.cache/pip/wheels/29/3c/fd/7ce5c3f0666dab31a50123635e6fb5e19ceb42ce38d4e58f45\n",
            "Successfully built sacremoses\n",
            "Installing collected packages: sacremoses, sentencepiece, tokenizers, transformers\n",
            "Successfully installed sacremoses-0.0.43 sentencepiece-0.1.90 tokenizers-0.7.0 transformers-2.9.1\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "T7gpEg0yvSVc",
        "colab_type": "code",
        "outputId": "2bbb87a4-6ce2-43bc-87ec-2e166dc29e3a",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 222,
          "referenced_widgets": [
            "b38b9470cc694dd2aa6d36982fecc94f",
            "6ad42f91046e4ba2a0fe329946bf57b6",
            "931883ddbf7345dfbfcf091c53c4861a",
            "db9de67daefd4065b390637c0294e2b6",
            "5379dd7943304dfdb1e96162e1f85832",
            "f557d0df4c664ec0a2014056ad70447a",
            "8b69d12f2711475e945c6a80b7516d5e",
            "36c77532c74b4f9388d8b182f9628f2b"
          ]
        }
      },
      "source": [
        "# Let's import a fast tokenizer that can work on batched inputs\n",
        "# (the 'Fast' tokenizers in HuggingFace)\n",
        "from transformers import BertTokenizerFast\n",
        "\n",
        "tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')"
      ],
      "execution_count": 24,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "INFO:transformers.file_utils:PyTorch version 1.5.0+cu101 available.\n",
            "INFO:transformers.file_utils:TensorFlow version 2.2.0 available.\n",
            "INFO:filelock:Lock 140135286997736 acquired on /root/.cache/torch/transformers/5e8a2b4893d13790ed4150ca1906be5f7a03d6c4ddf62296c383f6db42814db2.e13dbb970cb325137104fb2e5f36fe865f27746c6b526f6352861b1980eb80b1.lock\n",
            "INFO:transformers.file_utils:https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt not found in cache or force_download set to True, downloading to /root/.cache/torch/transformers/tmprrl6cjz3\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "b38b9470cc694dd2aa6d36982fecc94f",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=213450.0, style=ProgressStyle(descripti…"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "INFO:transformers.file_utils:storing https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt in cache at /root/.cache/torch/transformers/5e8a2b4893d13790ed4150ca1906be5f7a03d6c4ddf62296c383f6db42814db2.e13dbb970cb325137104fb2e5f36fe865f27746c6b526f6352861b1980eb80b1\n",
            "INFO:transformers.file_utils:creating metadata file for /root/.cache/torch/transformers/5e8a2b4893d13790ed4150ca1906be5f7a03d6c4ddf62296c383f6db42814db2.e13dbb970cb325137104fb2e5f36fe865f27746c6b526f6352861b1980eb80b1\n",
            "INFO:filelock:Lock 140135286997736 released on /root/.cache/torch/transformers/5e8a2b4893d13790ed4150ca1906be5f7a03d6c4ddf62296c383f6db42814db2.e13dbb970cb325137104fb2e5f36fe865f27746c6b526f6352861b1980eb80b1.lock\n",
            "INFO:transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt from cache at /root/.cache/torch/transformers/5e8a2b4893d13790ed4150ca1906be5f7a03d6c4ddf62296c383f6db42814db2.e13dbb970cb325137104fb2e5f36fe865f27746c6b526f6352861b1980eb80b1\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "fAmLTPC9vSVe",
        "colab_type": "code",
        "outputId": "ee6c1796-bffa-401f-95a3-85dd1635945d",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 105
        }
      },
      "source": [
        "# Now let's batch tokenize our dataset 'context'\n",
        "dataset = dataset.map(lambda example: tokenizer.batch_encode_plus(example['context']),\n",
        "                      batched=True)\n",
        "\n",
        "print(\"dataset[0]\", dataset[0])"
      ],
      "execution_count": 25,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "INFO:nlp.arrow_dataset:Caching processed dataset at /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/cache-032674ed152efc659573f8fb4916bc54.arrow\n",
            "100%|██████████| 2/2 [00:00<00:00,  6.87it/s]\n",
            "INFO:nlp.arrow_writer:Done writing 1057 examples in 4749270 bytes /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/cache-032674ed152efc659573f8fb4916bc54.arrow.\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "dataset[0] {'context': 'Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi\\'s Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the \"golden anniversary\" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as \"Super Bowl L\"), so that the logo could prominently feature the Arabic numerals 50.', 'question': '0: Which NFL team represented the AFC at Super Bowl 50?', 'answers.text': ['Denver Broncos', 'Denver Broncos', 'Denver Broncos'], 'answers.answer_start': [177, 177, 177], 'new_title': 'Wouhahh: My cutest title: My cute title: Super_Bowl_50', 'input_ids': [101, 3198, 5308, 1851, 1108, 1126, 1237, 1709, 1342, 1106, 4959, 1103, 3628, 1104, 1103, 1305, 2289, 1453, 113, 4279, 114, 1111, 1103, 1410, 1265, 119, 1109, 1237, 2289, 3047, 113, 10402, 114, 3628, 7068, 14722, 2378, 1103, 1305, 2289, 3047, 113, 24743, 114, 3628, 2938, 13598, 1572, 782, 1275, 1106, 7379, 1147, 1503, 3198, 5308, 1641, 119, 1109, 1342, 1108, 1307, 1113, 1428, 128, 117, 1446, 117, 1120, 12388, 112, 188, 3339, 1107, 1103, 1727, 2948, 2410, 3894, 1120, 3364, 10200, 117, 1756, 119, 1249, 1142, 1108, 1103, 13163, 3198, 5308, 117, 1103, 2074, 13463, 1103, 107, 5404, 5453, 107, 1114, 1672, 2284, 118, 12005, 11751, 117, 1112, 1218, 1112, 7818, 28117, 20080, 16264, 1103, 3904, 1104, 10505, 1296, 3198, 5308, 1342, 1114, 2264, 183, 15447, 16179, 113, 1223, 1134, 1103, 1342, 1156, 1138, 1151, 1227, 1112, 107, 3198, 5308, 149, 107, 114, 117, 1177, 1115, 1103, 7998, 1180, 15199, 2672, 1103, 
4944, 183, 15447, 16179, 1851, 119, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "kNaJdKskvSVf",
        "colab_type": "code",
        "outputId": "8cb8e12b-de3a-40b7-e44e-93dfce3223ae",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "# we have added additional columns\n",
        "print(dataset.column_names)"
      ],
      "execution_count": 26,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "['context', 'question', 'answers.text', 'answers.answer_start', 'new_title', 'input_ids', 'token_type_ids', 'attention_mask']\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "m3To8ztMvSVj",
        "colab_type": "code",
        "outputId": "6d643f61-46f3-4e81-ee2a-4fc6f5ea029c",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 68
        }
      },
      "source": [
        "# Let show a more complex processing with the full preparation of the SQuAD dataset\n",
        "# for training a model from Transformers\n",
        "def convert_to_features(batch):\n",
        "    # Tokenize contexts and questions (as pairs of inputs)\n",
        "    # keep offset mappings for evaluation\n",
        "    input_pairs = list(zip(batch['context'], batch['question']))\n",
        "    encodings = tokenizer.batch_encode_plus(input_pairs,\n",
        "                                            pad_to_max_length=True,\n",
        "                                            return_offsets_mapping=True)\n",
        "\n",
        "    # Compute start and end tokens for labels\n",
        "    start_positions, end_positions = [], []\n",
        "    for i, (text, start) in enumerate(zip(batch['answers.text'], batch['answers.answer_start'])):\n",
        "        first_char = start[0]\n",
        "        last_char = first_char + len(text[0]) - 1\n",
        "        start_positions.append(encodings.char_to_token(i, first_char))\n",
        "        end_positions.append(encodings.char_to_token(i, last_char))\n",
        "\n",
        "    encodings.update({'start_positions': start_positions, 'end_positions': end_positions})\n",
        "    return encodings\n",
        "\n",
        "dataset = dataset.map(convert_to_features, batched=True)"
      ],
      "execution_count": 27,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "INFO:nlp.arrow_dataset:Caching processed dataset at /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/cache-4dc863be0d0c1e9e15749b73b0061784.arrow\n",
            "100%|██████████| 2/2 [00:01<00:00,  1.54it/s]\n",
            "INFO:nlp.arrow_writer:Done writing 1057 examples in 21643250 bytes /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/cache-4dc863be0d0c1e9e15749b73b0061784.arrow.\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "KBnmSa46vSVl",
        "colab_type": "code",
        "outputId": "824af1cc-11f5-4781-ccff-2d6b257eed33",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 51
        }
      },
      "source": [
        "# Now our dataset comprise the labels for the start and end position\n",
        "# as well as the offsets for converting back tokens\n",
        "# in span of the original string for evaluation\n",
        "print(\"column_names\", dataset.column_names)\n",
        "print(\"start_positions\", dataset[:5]['start_positions'])"
      ],
      "execution_count": 28,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "column_names ['context', 'question', 'answers.text', 'answers.answer_start', 'new_title', 'input_ids', 'token_type_ids', 'attention_mask', 'offset_mapping', 'start_positions', 'end_positions']\n",
            "start_positions [34, 45, 80, 34, 98]\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NzOXxNzQvSVo",
        "colab_type": "text"
      },
      "source": [
        "## Formating outputs for numpy/torch/tensorflow\n",
        "\n",
        "Now that we have tokenized our inputs, we probably want to use this dataset in a `torch.Dataloader` or a `tf.data.Dataset`.\n",
        "\n",
        "To be able to do this we need to tweak two things:\n",
        "\n",
        "- format the indexing (`__getitem__`) to return numpy/pytorch/tensorflow tensors, instead of python objects, and probably\n",
        "- format the indexing (`__getitem__`) to return only the subset of the columns that we need for our model inputs.\n",
        "\n",
        "  We don't want the columns `id` or `title` as inputs to train our model, but we could still want to keep them in the dataset, for instance for the evaluation of the model.\n",
        "    \n",
        "This is handled by the `.set_format(type: Union[None, str], columns: Union[None, str, List[str]])` where:\n",
        "\n",
        "- `type` define the return type for our dataset `__getitem__` method and is one of `[None, 'numpy', 'pandas', 'torch', 'tensorflow']` (`None` means return python objects), and\n",
        "- `columns` define the columns returned by `__getitem__` and takes the name of a column in the dataset or a list of columns to return (`None` means return all columns)."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "aU2h_qQDvSVo",
        "colab_type": "code",
        "outputId": "677d9ee9-76a9-40af-9c4c-6a720197609c",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 139
        }
      },
      "source": [
        "columns_to_return = ['input_ids', 'token_type_ids', 'attention_mask',\n",
        "                     'start_positions', 'end_positions']\n",
        "\n",
        "dataset.set_format(type='torch',\n",
        "                   columns=columns_to_return)\n",
        "\n",
        "# Our dataset indexing output is now ready for being used in a pytorch dataloader\n",
        "print('\\n'.join([' '.join((n, str(type(t)), str(t.shape))) for n, t in dataset[:10].items()]))"
      ],
      "execution_count": 29,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "INFO:nlp.arrow_dataset:Set __getitem__(key) output type to torch for ['input_ids', 'token_type_ids', 'attention_mask', 'start_positions', 'end_positions'] columns  (when key is int or slice) and don't output other (un-formated) columns.\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "input_ids <class 'torch.Tensor'> torch.Size([10, 451])\n",
            "token_type_ids <class 'torch.Tensor'> torch.Size([10, 451])\n",
            "attention_mask <class 'torch.Tensor'> torch.Size([10, 451])\n",
            "start_positions <class 'torch.Tensor'> torch.Size([10])\n",
            "end_positions <class 'torch.Tensor'> torch.Size([10])\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Wj1ukGIuvSVq",
        "colab_type": "code",
        "outputId": "ccacd424-17c8-417a-8ae7-0c9e25491e30",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "# Note that the columns are not removed from the dataset, just not returned when calling __getitem__\n",
        "# Similarly the inner type of the dataset is not changed to torch.Tensor, the conversion and filtering is done on-the-fly when querying the dataset\n",
        "print(dataset.column_names)"
      ],
      "execution_count": 30,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "['context', 'question', 'answers.text', 'answers.answer_start', 'new_title', 'input_ids', 'token_type_ids', 'attention_mask', 'offset_mapping', 'start_positions', 'end_positions']\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "pWmmUnlpvSVs",
        "colab_type": "code",
        "outputId": "85ed5497-cd09-43ae-ca78-4a206680ca24",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 221
        }
      },
      "source": [
        "# We can remove the formating with `.reset_format()`\n",
        "# or, identically, a call to `.set_format()` with no arguments\n",
        "dataset.reset_format()\n",
        "\n",
        "print('\\n'.join([' '.join((n, str(type(t)))) for n, t in dataset[:10].items()]))"
      ],
      "execution_count": 31,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "INFO:nlp.arrow_dataset:Set __getitem__(key) output type to python objects for no columns  (when key is int or slice) and don't output other (un-formated) columns.\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "context <class 'list'>\n",
            "question <class 'list'>\n",
            "answers.text <class 'list'>\n",
            "answers.answer_start <class 'list'>\n",
            "new_title <class 'list'>\n",
            "input_ids <class 'list'>\n",
            "token_type_ids <class 'list'>\n",
            "attention_mask <class 'list'>\n",
            "offset_mapping <class 'list'>\n",
            "start_positions <class 'list'>\n",
            "end_positions <class 'list'>\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "VyUOA07svSVu",
        "colab_type": "code",
        "outputId": "dbff43a2-7012-45bb-cc7c-26750958c2bc",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 238
        }
      },
      "source": [
        "# The current format can be checked with `.format`,\n",
        "# which is a dict of the type and formating\n",
        "dataset.format"
      ],
      "execution_count": 32,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'columns': ['context',\n",
              "  'question',\n",
              "  'answers.text',\n",
              "  'answers.answer_start',\n",
              "  'new_title',\n",
              "  'input_ids',\n",
              "  'token_type_ids',\n",
              "  'attention_mask',\n",
              "  'offset_mapping',\n",
              "  'start_positions',\n",
              "  'end_positions'],\n",
              " 'output_all_columns': False,\n",
              " 'type': 'python'}"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 32
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xyi2eMeSvSVv",
        "colab_type": "text"
      },
      "source": [
        "# Wrapping this all up (PyTorch)\n",
        "\n",
        "Let's wrap this all up with the full code to load and prepare SQuAD for training a PyTorch model from HuggingFace `transformers` library.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "QvExTIZWvSVw",
        "colab_type": "code",
        "outputId": "980a3dc8-1898-443f-de59-97c5a611b68c",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 360
        }
      },
      "source": [
        "import nlp\n",
        "import torch \n",
        "from transformers import BertTokenizerFast\n",
        "\n",
        "# Load our training dataset and tokenizer\n",
        "dataset = nlp.load_dataset('squad')\n",
        "tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')\n",
        "\n",
        "def get_correct_alignement(context, answer):\n",
        "    \"\"\" Some original examples in SQuAD have indices wrong by 1 or 2 character. We test and fix this here. \"\"\"\n",
        "    gold_text = answer['text'][0]\n",
        "    start_idx = answer['answer_start'][0]\n",
        "    end_idx = start_idx + len(gold_text)\n",
        "    if context[start_idx:end_idx] == gold_text:\n",
        "        return start_idx, end_idx       # When the gold label position is good\n",
        "    elif context[start_idx-1:end_idx-1] == gold_text:\n",
        "        return start_idx-1, end_idx-1   # When the gold label is off by one character\n",
        "    elif context[start_idx-2:end_idx-2] == gold_text:\n",
        "        return start_idx-2, end_idx-2   # When the gold label is off by two character\n",
        "    else:\n",
        "        raise ValueError()\n",
        "\n",
        "# Tokenize our training dataset\n",
        "def convert_to_features(example_batch):\n",
        "    # Tokenize contexts and questions (as pairs of inputs)\n",
        "    input_pairs = list(zip(example_batch['context'], example_batch['question']))\n",
        "    encodings = tokenizer.batch_encode_plus(input_pairs, pad_to_max_length=True)\n",
        "\n",
        "    # Compute start and end tokens for labels using Transformers's fast tokenizers alignement methodes.\n",
        "    start_positions, end_positions = [], []\n",
        "    for i, (context, answer) in enumerate(zip(example_batch['context'], example_batch['answers'])):\n",
        "        start_idx, end_idx = get_correct_alignement(context, answer)\n",
        "        start_positions.append(encodings.char_to_token(i, start_idx))\n",
        "        end_positions.append(encodings.char_to_token(i, end_idx-1))\n",
        "    encodings.update({'start_positions': start_positions,\n",
        "                      'end_positions': end_positions})\n",
        "    return encodings\n",
        "\n",
        "dataset['train'] = dataset['train'].map(convert_to_features, batched=True)\n",
        "\n",
        "# Format our dataset to outputs torch.Tensor to train a pytorch model\n",
        "columns = ['input_ids', 'token_type_ids', 'attention_mask', 'start_positions', 'end_positions']\n",
        "dataset['train'].set_format(type='torch', columns=columns)\n",
        "\n",
        "# Instantiate a PyTorch Dataloader around our dataset\n",
        "dataloader = torch.utils.data.DataLoader(dataset['train'], batch_size=8)"
      ],
      "execution_count": 33,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "INFO:nlp.load:Checking /root/.cache/huggingface/datasets/09ec6948d9db29db9a2dcd08df97ac45bccfa6aa104ea62d73c97fa4aaa5cd6c.f373b0de1570ca81b50bb03bd371604f7979e35de2cfcf2a3b4521d0b3104d9b.py for additional imports.\n",
            "INFO:filelock:Lock 140132889271152 acquired on /root/.cache/huggingface/datasets/09ec6948d9db29db9a2dcd08df97ac45bccfa6aa104ea62d73c97fa4aaa5cd6c.f373b0de1570ca81b50bb03bd371604f7979e35de2cfcf2a3b4521d0b3104d9b.py.lock\n",
            "INFO:nlp.load:Found main folder for dataset https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/squad/squad.py at /usr/local/lib/python3.6/dist-packages/nlp/datasets/squad\n",
            "INFO:nlp.load:Found specific version folder for dataset https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/squad/squad.py at /usr/local/lib/python3.6/dist-packages/nlp/datasets/squad/c0327553d80335e3a3283527f64d9778df7ad04ab28f38148d072782712bb670\n",
            "INFO:nlp.load:Found script file from https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/squad/squad.py to /usr/local/lib/python3.6/dist-packages/nlp/datasets/squad/c0327553d80335e3a3283527f64d9778df7ad04ab28f38148d072782712bb670/squad.py\n",
            "INFO:nlp.load:Found dataset infos file from https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/squad/dataset_infos.json to /usr/local/lib/python3.6/dist-packages/nlp/datasets/squad/c0327553d80335e3a3283527f64d9778df7ad04ab28f38148d072782712bb670/dataset_infos.json\n",
            "INFO:nlp.load:Found metadata file for dataset https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/squad/squad.py at /usr/local/lib/python3.6/dist-packages/nlp/datasets/squad/c0327553d80335e3a3283527f64d9778df7ad04ab28f38148d072782712bb670/squad.json\n",
            "INFO:filelock:Lock 140132889271152 released on /root/.cache/huggingface/datasets/09ec6948d9db29db9a2dcd08df97ac45bccfa6aa104ea62d73c97fa4aaa5cd6c.f373b0de1570ca81b50bb03bd371604f7979e35de2cfcf2a3b4521d0b3104d9b.py.lock\n",
            "INFO:nlp.builder:No config specified, defaulting to first: squad/plain_text\n",
            "INFO:nlp.info:Loading Dataset Infos from /usr/local/lib/python3.6/dist-packages/nlp/datasets/squad/c0327553d80335e3a3283527f64d9778df7ad04ab28f38148d072782712bb670\n",
            "INFO:nlp.builder:Overwrite dataset info from restored data version.\n",
            "INFO:nlp.info:Loading Dataset info from /root/.cache/huggingface/datasets/squad/plain_text/1.0.0\n",
            "INFO:nlp.builder:Reusing dataset squad (/root/.cache/huggingface/datasets/squad/plain_text/1.0.0)\n",
            "INFO:nlp.builder:Constructing Dataset for split None, from /root/.cache/huggingface/datasets/squad/plain_text/1.0.0\n",
            "INFO:transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt from cache at /root/.cache/torch/transformers/5e8a2b4893d13790ed4150ca1906be5f7a03d6c4ddf62296c383f6db42814db2.e13dbb970cb325137104fb2e5f36fe865f27746c6b526f6352861b1980eb80b1\n",
            "INFO:nlp.arrow_dataset:Caching processed dataset at /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/cache-e59fce014c308c7a796f7166088ad089.arrow\n",
            "100%|██████████| 88/88 [00:39<00:00,  2.25it/s]\n",
            "INFO:nlp.arrow_writer:Done writing 87599 examples in 1098782410 bytes /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/cache-e59fce014c308c7a796f7166088ad089.arrow.\n",
            "INFO:nlp.arrow_dataset:Set __getitem__(key) output type to torch for ['input_ids', 'token_type_ids', 'attention_mask', 'start_positions', 'end_positions'] columns  (when key is int or slice) and don't output other (un-formated) columns.\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4mHnwMx2vSVx",
        "colab_type": "code",
        "outputId": "80e45197-29bf-43f1-c92c-babe46623ed4",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 866,
          "referenced_widgets": [
            "454a7d19cdb04b7e93efec4e2b374345",
            "adb39c17c61947ce9fb241311e929135",
            "557a58848c2a498882429ad3b3efe29f",
            "ea2a6e311e164c00b7e9801a58aa44c4",
            "72b94b71b4214c53891769dc1bd1d15d",
            "cf063e09a69243d3af2f6a97f4f6417e",
            "fafb257a0fe24cc9a8ea3e1a0b87e42b",
            "0e9a15d300d547dbae5f6fc6a1df1782",
            "049655d977ec42ddaa2079422cc0a635",
            "4807314ef3b342a1baf6dd2e0cd16cd5",
            "508c6262c25a41de89be1499c973a126",
            "a45ffab324f94d4882863e56b3d43678",
            "7d3af3d2939c49e2b113d92e8573003f",
            "5627b08bdb1f4a9fa2d58c63481e1129",
            "259b90b57d064c938b47d61c279ae237",
            "ef365e73c93340a983ace25acfa41c17"
          ]
        }
      },
      "source": [
        "# Let's load a pretrained Bert model and a simple optimizer\n",
        "from transformers import BertForQuestionAnswering\n",
        "\n",
        "model = BertForQuestionAnswering.from_pretrained('distilbert-base-cased')\n",
        "optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)"
      ],
      "execution_count": 34,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "INFO:filelock:Lock 140132924425608 acquired on /root/.cache/torch/transformers/774d52b0be7c2f621ac9e64708a8b80f22059f6d0e264e1bdc4f4d71c386c4ea.f44aaaab97e2ee0f8d9071a5cd694e19bf664237a92aea20ebe04ddb7097b494.lock\n",
            "INFO:transformers.file_utils:https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-cased-config.json not found in cache or force_download set to True, downloading to /root/.cache/torch/transformers/tmp9xkcwpsh\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "454a7d19cdb04b7e93efec4e2b374345",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=411.0, style=ProgressStyle(description_…"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "INFO:transformers.file_utils:storing https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-cased-config.json in cache at /root/.cache/torch/transformers/774d52b0be7c2f621ac9e64708a8b80f22059f6d0e264e1bdc4f4d71c386c4ea.f44aaaab97e2ee0f8d9071a5cd694e19bf664237a92aea20ebe04ddb7097b494\n",
            "INFO:transformers.file_utils:creating metadata file for /root/.cache/torch/transformers/774d52b0be7c2f621ac9e64708a8b80f22059f6d0e264e1bdc4f4d71c386c4ea.f44aaaab97e2ee0f8d9071a5cd694e19bf664237a92aea20ebe04ddb7097b494\n",
            "INFO:filelock:Lock 140132924425608 released on /root/.cache/torch/transformers/774d52b0be7c2f621ac9e64708a8b80f22059f6d0e264e1bdc4f4d71c386c4ea.f44aaaab97e2ee0f8d9071a5cd694e19bf664237a92aea20ebe04ddb7097b494.lock\n",
            "INFO:transformers.configuration_utils:loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-cased-config.json from cache at /root/.cache/torch/transformers/774d52b0be7c2f621ac9e64708a8b80f22059f6d0e264e1bdc4f4d71c386c4ea.f44aaaab97e2ee0f8d9071a5cd694e19bf664237a92aea20ebe04ddb7097b494\n",
            "INFO:transformers.configuration_utils:Model config BertConfig {\n",
            "  \"activation\": \"gelu\",\n",
            "  \"attention_dropout\": 0.1,\n",
            "  \"attention_probs_dropout_prob\": 0.1,\n",
            "  \"dim\": 768,\n",
            "  \"dropout\": 0.1,\n",
            "  \"hidden_act\": \"gelu\",\n",
            "  \"hidden_dim\": 3072,\n",
            "  \"hidden_dropout_prob\": 0.1,\n",
            "  \"hidden_size\": 768,\n",
            "  \"initializer_range\": 0.02,\n",
            "  \"intermediate_size\": 3072,\n",
            "  \"layer_norm_eps\": 1e-12,\n",
            "  \"max_position_embeddings\": 512,\n",
            "  \"model_type\": \"bert\",\n",
            "  \"n_heads\": 12,\n",
            "  \"n_layers\": 6,\n",
            "  \"num_attention_heads\": 12,\n",
            "  \"num_hidden_layers\": 12,\n",
            "  \"output_past\": true,\n",
            "  \"pad_token_id\": 0,\n",
            "  \"qa_dropout\": 0.1,\n",
            "  \"seq_classif_dropout\": 0.2,\n",
            "  \"sinusoidal_pos_embds\": false,\n",
            "  \"tie_weights_\": true,\n",
            "  \"type_vocab_size\": 2,\n",
            "  \"vocab_size\": 28996\n",
            "}\n",
            "\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "INFO:filelock:Lock 140132908731584 acquired on /root/.cache/torch/transformers/185eb053d63bc5c2d6994e4b2a8e5eb59f31af90db9c5fae5e38c32a986462cb.857b7d17ad0bfaa2eec50caf481575bab1073303fef16bd5f29bc5248b2b8c7d.lock\n",
            "INFO:transformers.file_utils:https://cdn.huggingface.co/distilbert-base-cased-pytorch_model.bin not found in cache or force_download set to True, downloading to /root/.cache/torch/transformers/tmpunvn5ujm\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "049655d977ec42ddaa2079422cc0a635",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=263273408.0, style=ProgressStyle(descri…"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "INFO:transformers.file_utils:storing https://cdn.huggingface.co/distilbert-base-cased-pytorch_model.bin in cache at /root/.cache/torch/transformers/185eb053d63bc5c2d6994e4b2a8e5eb59f31af90db9c5fae5e38c32a986462cb.857b7d17ad0bfaa2eec50caf481575bab1073303fef16bd5f29bc5248b2b8c7d\n",
            "INFO:transformers.file_utils:creating metadata file for /root/.cache/torch/transformers/185eb053d63bc5c2d6994e4b2a8e5eb59f31af90db9c5fae5e38c32a986462cb.857b7d17ad0bfaa2eec50caf481575bab1073303fef16bd5f29bc5248b2b8c7d\n",
            "INFO:filelock:Lock 140132908731584 released on /root/.cache/torch/transformers/185eb053d63bc5c2d6994e4b2a8e5eb59f31af90db9c5fae5e38c32a986462cb.857b7d17ad0bfaa2eec50caf481575bab1073303fef16bd5f29bc5248b2b8c7d.lock\n",
            "INFO:transformers.modeling_utils:loading weights file https://cdn.huggingface.co/distilbert-base-cased-pytorch_model.bin from cache at /root/.cache/torch/transformers/185eb053d63bc5c2d6994e4b2a8e5eb59f31af90db9c5fae5e38c32a986462cb.857b7d17ad0bfaa2eec50caf481575bab1073303fef16bd5f29bc5248b2b8c7d\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "INFO:transformers.modeling_utils:Weights of BertForQuestionAnswering not initialized from pretrained model: ['embeddings.word_embeddings.weight', 'embeddings.position_embeddings.weight', 'embeddings.token_type_embeddings.weight', 'embeddings.LayerNorm.weight', 'embeddings.LayerNorm.bias', 'encoder.layer.0.attention.self.query.weight', 'encoder.layer.0.attention.self.query.bias', 'encoder.layer.0.attention.self.key.weight', 'encoder.layer.0.attention.self.key.bias', 'encoder.layer.0.attention.self.value.weight', 'encoder.layer.0.attention.self.value.bias', 'encoder.layer.0.attention.output.dense.weight', 'encoder.layer.0.attention.output.dense.bias', 'encoder.layer.0.attention.output.LayerNorm.weight', 'encoder.layer.0.attention.output.LayerNorm.bias', 'encoder.layer.0.intermediate.dense.weight', 'encoder.layer.0.intermediate.dense.bias', 'encoder.layer.0.output.dense.weight', 'encoder.layer.0.output.dense.bias', 'encoder.layer.0.output.LayerNorm.weight', 'encoder.layer.0.output.LayerNorm.bias', 'encoder.layer.1.attention.self.query.weight', 'encoder.layer.1.attention.self.query.bias', 'encoder.layer.1.attention.self.key.weight', 'encoder.layer.1.attention.self.key.bias', 'encoder.layer.1.attention.self.value.weight', 'encoder.layer.1.attention.self.value.bias', 'encoder.layer.1.attention.output.dense.weight', 'encoder.layer.1.attention.output.dense.bias', 'encoder.layer.1.attention.output.LayerNorm.weight', 'encoder.layer.1.attention.output.LayerNorm.bias', 'encoder.layer.1.intermediate.dense.weight', 'encoder.layer.1.intermediate.dense.bias', 'encoder.layer.1.output.dense.weight', 'encoder.layer.1.output.dense.bias', 'encoder.layer.1.output.LayerNorm.weight', 'encoder.layer.1.output.LayerNorm.bias', 'encoder.layer.2.attention.self.query.weight', 'encoder.layer.2.attention.self.query.bias', 'encoder.layer.2.attention.self.key.weight', 'encoder.layer.2.attention.self.key.bias', 'encoder.layer.2.attention.self.value.weight', 
'encoder.layer.2.attention.self.value.bias', 'encoder.layer.2.attention.output.dense.weight', 'encoder.layer.2.attention.output.dense.bias', 'encoder.layer.2.attention.output.LayerNorm.weight', 'encoder.layer.2.attention.output.LayerNorm.bias', 'encoder.layer.2.intermediate.dense.weight', 'encoder.layer.2.intermediate.dense.bias', 'encoder.layer.2.output.dense.weight', 'encoder.layer.2.output.dense.bias', 'encoder.layer.2.output.LayerNorm.weight', 'encoder.layer.2.output.LayerNorm.bias', 'encoder.layer.3.attention.self.query.weight', 'encoder.layer.3.attention.self.query.bias', 'encoder.layer.3.attention.self.key.weight', 'encoder.layer.3.attention.self.key.bias', 'encoder.layer.3.attention.self.value.weight', 'encoder.layer.3.attention.self.value.bias', 'encoder.layer.3.attention.output.dense.weight', 'encoder.layer.3.attention.output.dense.bias', 'encoder.layer.3.attention.output.LayerNorm.weight', 'encoder.layer.3.attention.output.LayerNorm.bias', 'encoder.layer.3.intermediate.dense.weight', 'encoder.layer.3.intermediate.dense.bias', 'encoder.layer.3.output.dense.weight', 'encoder.layer.3.output.dense.bias', 'encoder.layer.3.output.LayerNorm.weight', 'encoder.layer.3.output.LayerNorm.bias', 'encoder.layer.4.attention.self.query.weight', 'encoder.layer.4.attention.self.query.bias', 'encoder.layer.4.attention.self.key.weight', 'encoder.layer.4.attention.self.key.bias', 'encoder.layer.4.attention.self.value.weight', 'encoder.layer.4.attention.self.value.bias', 'encoder.layer.4.attention.output.dense.weight', 'encoder.layer.4.attention.output.dense.bias', 'encoder.layer.4.attention.output.LayerNorm.weight', 'encoder.layer.4.attention.output.LayerNorm.bias', 'encoder.layer.4.intermediate.dense.weight', 'encoder.layer.4.intermediate.dense.bias', 'encoder.layer.4.output.dense.weight', 'encoder.layer.4.output.dense.bias', 'encoder.layer.4.output.LayerNorm.weight', 'encoder.layer.4.output.LayerNorm.bias', 'encoder.layer.5.attention.self.query.weight', 
'encoder.layer.5.attention.self.query.bias', 'encoder.layer.5.attention.self.key.weight', 'encoder.layer.5.attention.self.key.bias', 'encoder.layer.5.attention.self.value.weight', 'encoder.layer.5.attention.self.value.bias', 'encoder.layer.5.attention.output.dense.weight', 'encoder.layer.5.attention.output.dense.bias', 'encoder.layer.5.attention.output.LayerNorm.weight', 'encoder.layer.5.attention.output.LayerNorm.bias', 'encoder.layer.5.intermediate.dense.weight', 'encoder.layer.5.intermediate.dense.bias', 'encoder.layer.5.output.dense.weight', 'encoder.layer.5.output.dense.bias', 'encoder.layer.5.output.LayerNorm.weight', 'encoder.layer.5.output.LayerNorm.bias', 'encoder.layer.6.attention.self.query.weight', 'encoder.layer.6.attention.self.query.bias', 'encoder.layer.6.attention.self.key.weight', 'encoder.layer.6.attention.self.key.bias', 'encoder.layer.6.attention.self.value.weight', 'encoder.layer.6.attention.self.value.bias', 'encoder.layer.6.attention.output.dense.weight', 'encoder.layer.6.attention.output.dense.bias', 'encoder.layer.6.attention.output.LayerNorm.weight', 'encoder.layer.6.attention.output.LayerNorm.bias', 'encoder.layer.6.intermediate.dense.weight', 'encoder.layer.6.intermediate.dense.bias', 'encoder.layer.6.output.dense.weight', 'encoder.layer.6.output.dense.bias', 'encoder.layer.6.output.LayerNorm.weight', 'encoder.layer.6.output.LayerNorm.bias', 'encoder.layer.7.attention.self.query.weight', 'encoder.layer.7.attention.self.query.bias', 'encoder.layer.7.attention.self.key.weight', 'encoder.layer.7.attention.self.key.bias', 'encoder.layer.7.attention.self.value.weight', 'encoder.layer.7.attention.self.value.bias', 'encoder.layer.7.attention.output.dense.weight', 'encoder.layer.7.attention.output.dense.bias', 'encoder.layer.7.attention.output.LayerNorm.weight', 'encoder.layer.7.attention.output.LayerNorm.bias', 'encoder.layer.7.intermediate.dense.weight', 'encoder.layer.7.intermediate.dense.bias', 'encoder.layer.7.output.dense.weight', 
'encoder.layer.7.output.dense.bias', 'encoder.layer.7.output.LayerNorm.weight', 'encoder.layer.7.output.LayerNorm.bias', 'encoder.layer.8.attention.self.query.weight', 'encoder.layer.8.attention.self.query.bias', 'encoder.layer.8.attention.self.key.weight', 'encoder.layer.8.attention.self.key.bias', 'encoder.layer.8.attention.self.value.weight', 'encoder.layer.8.attention.self.value.bias', 'encoder.layer.8.attention.output.dense.weight', 'encoder.layer.8.attention.output.dense.bias', 'encoder.layer.8.attention.output.LayerNorm.weight', 'encoder.layer.8.attention.output.LayerNorm.bias', 'encoder.layer.8.intermediate.dense.weight', 'encoder.layer.8.intermediate.dense.bias', 'encoder.layer.8.output.dense.weight', 'encoder.layer.8.output.dense.bias', 'encoder.layer.8.output.LayerNorm.weight', 'encoder.layer.8.output.LayerNorm.bias', 'encoder.layer.9.attention.self.query.weight', 'encoder.layer.9.attention.self.query.bias', 'encoder.layer.9.attention.self.key.weight', 'encoder.layer.9.attention.self.key.bias', 'encoder.layer.9.attention.self.value.weight', 'encoder.layer.9.attention.self.value.bias', 'encoder.layer.9.attention.output.dense.weight', 'encoder.layer.9.attention.output.dense.bias', 'encoder.layer.9.attention.output.LayerNorm.weight', 'encoder.layer.9.attention.output.LayerNorm.bias', 'encoder.layer.9.intermediate.dense.weight', 'encoder.layer.9.intermediate.dense.bias', 'encoder.layer.9.output.dense.weight', 'encoder.layer.9.output.dense.bias', 'encoder.layer.9.output.LayerNorm.weight', 'encoder.layer.9.output.LayerNorm.bias', 'encoder.layer.10.attention.self.query.weight', 'encoder.layer.10.attention.self.query.bias', 'encoder.layer.10.attention.self.key.weight', 'encoder.layer.10.attention.self.key.bias', 'encoder.layer.10.attention.self.value.weight', 'encoder.layer.10.attention.self.value.bias', 'encoder.layer.10.attention.output.dense.weight', 'encoder.layer.10.attention.output.dense.bias', 'encoder.layer.10.attention.output.LayerNorm.weight', 
'encoder.layer.10.attention.output.LayerNorm.bias', 'encoder.layer.10.intermediate.dense.weight', 'encoder.layer.10.intermediate.dense.bias', 'encoder.layer.10.output.dense.weight', 'encoder.layer.10.output.dense.bias', 'encoder.layer.10.output.LayerNorm.weight', 'encoder.layer.10.output.LayerNorm.bias', 'encoder.layer.11.attention.self.query.weight', 'encoder.layer.11.attention.self.query.bias', 'encoder.layer.11.attention.self.key.weight', 'encoder.layer.11.attention.self.key.bias', 'encoder.layer.11.attention.self.value.weight', 'encoder.layer.11.attention.self.value.bias', 'encoder.layer.11.attention.output.dense.weight', 'encoder.layer.11.attention.output.dense.bias', 'encoder.layer.11.attention.output.LayerNorm.weight', 'encoder.layer.11.attention.output.LayerNorm.bias', 'encoder.layer.11.intermediate.dense.weight', 'encoder.layer.11.intermediate.dense.bias', 'encoder.layer.11.output.dense.weight', 'encoder.layer.11.output.dense.bias', 'encoder.layer.11.output.LayerNorm.weight', 'encoder.layer.11.output.LayerNorm.bias', 'pooler.dense.weight', 'pooler.dense.bias', 'qa_outputs.bias', 'qa_outputs.weight']\n",
            "INFO:transformers.modeling_utils:Weights from pretrained model not used in BertForQuestionAnswering: ['distilbert.embeddings.word_embeddings.weight', 'distilbert.embeddings.position_embeddings.weight', 'distilbert.embeddings.LayerNorm.weight', 'distilbert.embeddings.LayerNorm.bias', 'distilbert.transformer.layer.0.attention.q_lin.weight', 'distilbert.transformer.layer.0.attention.q_lin.bias', 'distilbert.transformer.layer.0.attention.k_lin.weight', 'distilbert.transformer.layer.0.attention.k_lin.bias', 'distilbert.transformer.layer.0.attention.v_lin.weight', 'distilbert.transformer.layer.0.attention.v_lin.bias', 'distilbert.transformer.layer.0.attention.out_lin.weight', 'distilbert.transformer.layer.0.attention.out_lin.bias', 'distilbert.transformer.layer.0.sa_layer_norm.weight', 'distilbert.transformer.layer.0.sa_layer_norm.bias', 'distilbert.transformer.layer.0.ffn.lin1.weight', 'distilbert.transformer.layer.0.ffn.lin1.bias', 'distilbert.transformer.layer.0.ffn.lin2.weight', 'distilbert.transformer.layer.0.ffn.lin2.bias', 'distilbert.transformer.layer.0.output_layer_norm.weight', 'distilbert.transformer.layer.0.output_layer_norm.bias', 'distilbert.transformer.layer.1.attention.q_lin.weight', 'distilbert.transformer.layer.1.attention.q_lin.bias', 'distilbert.transformer.layer.1.attention.k_lin.weight', 'distilbert.transformer.layer.1.attention.k_lin.bias', 'distilbert.transformer.layer.1.attention.v_lin.weight', 'distilbert.transformer.layer.1.attention.v_lin.bias', 'distilbert.transformer.layer.1.attention.out_lin.weight', 'distilbert.transformer.layer.1.attention.out_lin.bias', 'distilbert.transformer.layer.1.sa_layer_norm.weight', 'distilbert.transformer.layer.1.sa_layer_norm.bias', 'distilbert.transformer.layer.1.ffn.lin1.weight', 'distilbert.transformer.layer.1.ffn.lin1.bias', 'distilbert.transformer.layer.1.ffn.lin2.weight', 'distilbert.transformer.layer.1.ffn.lin2.bias', 'distilbert.transformer.layer.1.output_layer_norm.weight', 
'distilbert.transformer.layer.1.output_layer_norm.bias', 'distilbert.transformer.layer.2.attention.q_lin.weight', 'distilbert.transformer.layer.2.attention.q_lin.bias', 'distilbert.transformer.layer.2.attention.k_lin.weight', 'distilbert.transformer.layer.2.attention.k_lin.bias', 'distilbert.transformer.layer.2.attention.v_lin.weight', 'distilbert.transformer.layer.2.attention.v_lin.bias', 'distilbert.transformer.layer.2.attention.out_lin.weight', 'distilbert.transformer.layer.2.attention.out_lin.bias', 'distilbert.transformer.layer.2.sa_layer_norm.weight', 'distilbert.transformer.layer.2.sa_layer_norm.bias', 'distilbert.transformer.layer.2.ffn.lin1.weight', 'distilbert.transformer.layer.2.ffn.lin1.bias', 'distilbert.transformer.layer.2.ffn.lin2.weight', 'distilbert.transformer.layer.2.ffn.lin2.bias', 'distilbert.transformer.layer.2.output_layer_norm.weight', 'distilbert.transformer.layer.2.output_layer_norm.bias', 'distilbert.transformer.layer.3.attention.q_lin.weight', 'distilbert.transformer.layer.3.attention.q_lin.bias', 'distilbert.transformer.layer.3.attention.k_lin.weight', 'distilbert.transformer.layer.3.attention.k_lin.bias', 'distilbert.transformer.layer.3.attention.v_lin.weight', 'distilbert.transformer.layer.3.attention.v_lin.bias', 'distilbert.transformer.layer.3.attention.out_lin.weight', 'distilbert.transformer.layer.3.attention.out_lin.bias', 'distilbert.transformer.layer.3.sa_layer_norm.weight', 'distilbert.transformer.layer.3.sa_layer_norm.bias', 'distilbert.transformer.layer.3.ffn.lin1.weight', 'distilbert.transformer.layer.3.ffn.lin1.bias', 'distilbert.transformer.layer.3.ffn.lin2.weight', 'distilbert.transformer.layer.3.ffn.lin2.bias', 'distilbert.transformer.layer.3.output_layer_norm.weight', 'distilbert.transformer.layer.3.output_layer_norm.bias', 'distilbert.transformer.layer.4.attention.q_lin.weight', 'distilbert.transformer.layer.4.attention.q_lin.bias', 'distilbert.transformer.layer.4.attention.k_lin.weight', 
'distilbert.transformer.layer.4.attention.k_lin.bias', 'distilbert.transformer.layer.4.attention.v_lin.weight', 'distilbert.transformer.layer.4.attention.v_lin.bias', 'distilbert.transformer.layer.4.attention.out_lin.weight', 'distilbert.transformer.layer.4.attention.out_lin.bias', 'distilbert.transformer.layer.4.sa_layer_norm.weight', 'distilbert.transformer.layer.4.sa_layer_norm.bias', 'distilbert.transformer.layer.4.ffn.lin1.weight', 'distilbert.transformer.layer.4.ffn.lin1.bias', 'distilbert.transformer.layer.4.ffn.lin2.weight', 'distilbert.transformer.layer.4.ffn.lin2.bias', 'distilbert.transformer.layer.4.output_layer_norm.weight', 'distilbert.transformer.layer.4.output_layer_norm.bias', 'distilbert.transformer.layer.5.attention.q_lin.weight', 'distilbert.transformer.layer.5.attention.q_lin.bias', 'distilbert.transformer.layer.5.attention.k_lin.weight', 'distilbert.transformer.layer.5.attention.k_lin.bias', 'distilbert.transformer.layer.5.attention.v_lin.weight', 'distilbert.transformer.layer.5.attention.v_lin.bias', 'distilbert.transformer.layer.5.attention.out_lin.weight', 'distilbert.transformer.layer.5.attention.out_lin.bias', 'distilbert.transformer.layer.5.sa_layer_norm.weight', 'distilbert.transformer.layer.5.sa_layer_norm.bias', 'distilbert.transformer.layer.5.ffn.lin1.weight', 'distilbert.transformer.layer.5.ffn.lin1.bias', 'distilbert.transformer.layer.5.ffn.lin2.weight', 'distilbert.transformer.layer.5.ffn.lin2.bias', 'distilbert.transformer.layer.5.output_layer_norm.weight', 'distilbert.transformer.layer.5.output_layer_norm.bias', 'vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias']\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "biqDH9vpvSVz",
        "colab_type": "code",
        "outputId": "af298e23-ba60-4e6f-bf73-0ed5f5d45030",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 102
        }
      },
      "source": [
        "# Now let's train our model\n",
        "\n",
        "model.train()\n",
        "for i, batch in enumerate(dataloader):\n",
        "    outputs = model(**batch)\n",
        "    loss = outputs[0]\n",
        "    loss.backward()\n",
        "    optimizer.step()\n",
        "    model.zero_grad()\n",
        "    print(f'Step {i} - loss: {loss:.3}')\n",
        "    if i > 3:\n",
        "        break"
      ],
      "execution_count": 35,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Step 0 - loss: 5.97\n",
            "Step 1 - loss: 5.34\n",
            "Step 2 - loss: 5.12\n",
            "Step 3 - loss: 5.44\n",
            "Step 4 - loss: 4.87\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kxZQ9Ms_vSV1",
        "colab_type": "text"
      },
      "source": [
        "# Wrapping this all up (Tensorflow)\n",
        "\n",
        "Let's wrap this all up with the full code to load and prepare SQuAD for training a TensorFlow model (this only works from version 2.2.0 onwards)."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ZE8VSTYovSV2",
        "colab_type": "code",
        "outputId": "b58d3399-49d8-4b8b-8e93-1b20e17eb081",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 718
        }
      },
      "source": [
        "import tensorflow as tf\n",
        "import nlp\n",
        "from transformers import BertTokenizerFast\n",
        "\n",
        "# Load our training dataset and tokenizer\n",
        "train_tf_dataset = nlp.load_dataset('squad', split=\"train\")\n",
        "tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')\n",
        "\n",
        "# Tokenize our training dataset\n",
        "# The only difference here is that start_positions and end_positions\n",
        "# must each be a list of single-element lists => [[23], [45] ...]\n",
        "# instead of a flat list => [23, 45 ...]\n",
        "def convert_to_tf_features(example_batch):\n",
        "    # Tokenize contexts and questions (as pairs of inputs)\n",
        "    input_pairs = list(zip(example_batch['context'], example_batch['question']))\n",
        "    encodings = tokenizer.batch_encode_plus(input_pairs, pad_to_max_length=True)\n",
        "\n",
        "    # Compute start and end tokens for labels using the alignment methods of Transformers' fast tokenizers.\n",
        "    start_positions, end_positions = [], []\n",
        "    for i, (context, answer) in enumerate(zip(example_batch['context'], example_batch['answers'])):\n",
        "        start_idx, end_idx = get_correct_alignement(context, answer)\n",
        "        start_positions.append([encodings.char_to_token(i, start_idx)])\n",
        "        end_positions.append([encodings.char_to_token(i, end_idx-1)])\n",
        "    \n",
        "    if start_positions and end_positions:\n",
        "      encodings.update({'start_positions': start_positions,\n",
        "                        'end_positions': end_positions})\n",
        "    return encodings\n",
        "\n",
        "train_tf_dataset = train_tf_dataset.map(convert_to_tf_features, batched=True)\n",
        "\n",
        "columns = ['input_ids', 'token_type_ids', 'attention_mask', 'start_positions', 'end_positions']\n",
        "train_tf_dataset.set_format(type='tensorflow', columns=columns)\n",
        "features = {x: train_tf_dataset[x] for x in columns[:3]} \n",
        "labels = {\"output_1\": train_tf_dataset[\"start_positions\"]}\n",
        "labels[\"output_2\"] = train_tf_dataset[\"end_positions\"]\n",
        "tfdataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(8)"
      ],
      "execution_count": 36,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "INFO:nlp.load:Checking /root/.cache/huggingface/datasets/09ec6948d9db29db9a2dcd08df97ac45bccfa6aa104ea62d73c97fa4aaa5cd6c.f373b0de1570ca81b50bb03bd371604f7979e35de2cfcf2a3b4521d0b3104d9b.py for additional imports.\n",
            "INFO:filelock:Lock 140132926100256 acquired on /root/.cache/huggingface/datasets/09ec6948d9db29db9a2dcd08df97ac45bccfa6aa104ea62d73c97fa4aaa5cd6c.f373b0de1570ca81b50bb03bd371604f7979e35de2cfcf2a3b4521d0b3104d9b.py.lock\n",
            "INFO:nlp.load:Found main folder for dataset https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/squad/squad.py at /usr/local/lib/python3.6/dist-packages/nlp/datasets/squad\n",
            "INFO:nlp.load:Found specific version folder for dataset https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/squad/squad.py at /usr/local/lib/python3.6/dist-packages/nlp/datasets/squad/c0327553d80335e3a3283527f64d9778df7ad04ab28f38148d072782712bb670\n",
            "INFO:nlp.load:Found script file from https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/squad/squad.py to /usr/local/lib/python3.6/dist-packages/nlp/datasets/squad/c0327553d80335e3a3283527f64d9778df7ad04ab28f38148d072782712bb670/squad.py\n",
            "INFO:nlp.load:Found dataset infos file from https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/squad/dataset_infos.json to /usr/local/lib/python3.6/dist-packages/nlp/datasets/squad/c0327553d80335e3a3283527f64d9778df7ad04ab28f38148d072782712bb670/dataset_infos.json\n",
            "INFO:nlp.load:Found metadata file for dataset https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/squad/squad.py at /usr/local/lib/python3.6/dist-packages/nlp/datasets/squad/c0327553d80335e3a3283527f64d9778df7ad04ab28f38148d072782712bb670/squad.json\n",
            "INFO:filelock:Lock 140132926100256 released on /root/.cache/huggingface/datasets/09ec6948d9db29db9a2dcd08df97ac45bccfa6aa104ea62d73c97fa4aaa5cd6c.f373b0de1570ca81b50bb03bd371604f7979e35de2cfcf2a3b4521d0b3104d9b.py.lock\n",
            "INFO:nlp.builder:No config specified, defaulting to first: squad/plain_text\n",
            "INFO:nlp.info:Loading Dataset Infos from /usr/local/lib/python3.6/dist-packages/nlp/datasets/squad/c0327553d80335e3a3283527f64d9778df7ad04ab28f38148d072782712bb670\n",
            "INFO:nlp.builder:Overwrite dataset info from restored data version.\n",
            "INFO:nlp.info:Loading Dataset info from /root/.cache/huggingface/datasets/squad/plain_text/1.0.0\n",
            "INFO:nlp.builder:Reusing dataset squad (/root/.cache/huggingface/datasets/squad/plain_text/1.0.0)\n",
            "INFO:nlp.builder:Constructing Dataset for split train, from /root/.cache/huggingface/datasets/squad/plain_text/1.0.0\n",
            "INFO:transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt from cache at /root/.cache/torch/transformers/5e8a2b4893d13790ed4150ca1906be5f7a03d6c4ddf62296c383f6db42814db2.e13dbb970cb325137104fb2e5f36fe865f27746c6b526f6352861b1980eb80b1\n",
            "INFO:nlp.arrow_dataset:Caching processed dataset at /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/cache-9b8a4849358a2045b27513d017ea0dbe.arrow\n",
            "100%|██████████| 88/88 [00:36<00:00,  2.42it/s]\n",
            "INFO:nlp.arrow_writer:Done writing 87599 examples in 1099483906 bytes /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/cache-9b8a4849358a2045b27513d017ea0dbe.arrow.\n",
            "INFO:nlp.arrow_dataset:Set __getitem__(key) output type to tensorflow for ['input_ids', 'token_type_ids', 'attention_mask', 'start_positions', 'end_positions'] columns  (when key is int or slice) and don't output other (un-formated) columns.\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "error",
          "ename": "KeyboardInterrupt",
          "evalue": "ignored",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
            "\u001b[0;32m<ipython-input-36-fe6cfac490d0>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     32\u001b[0m \u001b[0mcolumns\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'input_ids'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'token_type_ids'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'attention_mask'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'start_positions'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'end_positions'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     33\u001b[0m \u001b[0mtrain_tf_dataset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_format\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'tensorflow'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcolumns\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcolumns\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 34\u001b[0;31m \u001b[0mfeatures\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mtrain_tf_dataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mcolumns\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     35\u001b[0m \u001b[0mlabels\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m\"output_1\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mtrain_tf_dataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"start_positions\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     36\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"output_2\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mtrain_tf_dataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"end_positions\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m<ipython-input-36-fe6cfac490d0>\u001b[0m in \u001b[0;36m<dictcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m     32\u001b[0m \u001b[0mcolumns\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'input_ids'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'token_type_ids'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'attention_mask'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'start_positions'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'end_positions'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     33\u001b[0m \u001b[0mtrain_tf_dataset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_format\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'tensorflow'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcolumns\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcolumns\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 34\u001b[0;31m \u001b[0mfeatures\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mtrain_tf_dataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mcolumns\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     35\u001b[0m \u001b[0mlabels\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m\"output_1\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mtrain_tf_dataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"start_positions\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     36\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"output_2\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mtrain_tf_dataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"end_positions\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/nlp/arrow_dataset.py\u001b[0m in \u001b[0;36m__getitem__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m    336\u001b[0m             \u001b[0mformat_type\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_format_type\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    337\u001b[0m             \u001b[0mformat_columns\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_format_columns\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 338\u001b[0;31m             \u001b[0moutput_all_columns\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_output_all_columns\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    339\u001b[0m         )\n\u001b[1;32m    340\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/nlp/arrow_dataset.py\u001b[0m in \u001b[0;36m_getitem\u001b[0;34m(self, key, format_type, format_columns, output_all_columns)\u001b[0m\n\u001b[1;32m    311\u001b[0m                         \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0marr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto_numpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0marr\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_data\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchunks\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    312\u001b[0m                     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 313\u001b[0;31m                         \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_convert_outputs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_data\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto_pylist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mformat_type\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mformat_type\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    314\u001b[0m                 \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    315\u001b[0m                     \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_data\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto_pylist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "y0dfw8K8vSV4",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Let's load a pretrained TF2 Bert model and a simple optimizer\n",
        "from transformers import TFBertForQuestionAnswering\n",
        "\n",
        "model = TFBertForQuestionAnswering.from_pretrained(\"bert-base-cased\")\n",
        "loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE, from_logits=True)\n",
        "opt = tf.keras.optimizers.Adam(learning_rate=3e-5)\n",
        "model.compile(optimizer=opt,\n",
        "              loss={'output_1': loss_fn, 'output_2': loss_fn},\n",
        "              loss_weights={'output_1': 1., 'output_2': 1.},\n",
        "              metrics=['accuracy'])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "TcYtiykmvSV6",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Now let's train our model\n",
        "\n",
        "model.fit(tfdataset, epochs=1, steps_per_epoch=3)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eREDXWP6vSV8",
        "colab_type": "text"
      },
      "source": [
        "# Metrics API\n",
        "\n",
        "`nlp` also provides easy access and sharing of metrics.\n",
        "\n",
        "This aspect of the library is still experimental and the API may still evolve more than the datasets API.\n",
        "\n",
        "Like datasets, metrics are added as small scripts wrapping common metrics in a common API.\n",
        "\n",
        "There are several reasons you may want to use metrics with `nlp`, in particular:\n",
        "\n",
        "- metrics for specific datasets like GLUE or SQuAD are provided out-of-the-box in a simple, convenient and consistent way integrated with the dataset,\n",
        "- metrics in `nlp` leverage the powerful backend to provide smart features out-of-the-box like support for distributed evaluation in PyTorch"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uUoGMMVKvSV8",
        "colab_type": "text"
      },
      "source": [
        "## Using metrics\n",
        "\n",
        "Using metrics is pretty simple. They have two main methods: `.compute(predictions, references)` to directly compute the metric, and `.add(predictions, references)` to only store some results if you want to do the evaluation in one go at the end.\n",
        "\n",
        "Here is a quick gist of a standard use of metrics (the simplest usage):\n",
        "```python\n",
        "import nlp\n",
        "bleu_metric = nlp.load_metric('bleu')\n",
        "\n",
        "# If you only have a single iteration, you can easily compute the score like this\n",
        "predictions = model(inputs)\n",
        "score = bleu_metric.compute(predictions, references)\n",
        "\n",
        "# If you have a loop, you can \"add\" your predictions and references at each iteration instead of having to save them yourself (the metric object stores them efficiently for you)\n",
        "for batch in dataloader:\n",
        "    model_inputs, targets = batch\n",
        "    predictions = model(model_inputs)\n",
        "    bleu_metric.add(predictions, targets)\n",
        "score = bleu_metric.compute()  # Compute the score from all the stored predictions/references\n",
        "```\n",
        "\n",
        "Here is a quick gist of a use in a distributed torch setup (should work for any python multi-process setup actually). It's pretty much identical to the second example above:\n",
        "```python\n",
        "import nlp\n",
        "# You need to give the total number of parallel python processes (num_process) and the id of each process (process_id)\n",
        "bleu = nlp.load_metric('bleu', process_id=torch.distributed.get_rank(), num_process=torch.distributed.get_world_size())\n",
        "\n",
        "for batch in dataloader:\n",
        "    model_inputs, targets = batch\n",
        "    predictions = model(model_inputs)\n",
        "    bleu.add(predictions, targets)\n",
        "score = bleu.compute()  # Compute the score on the first node by default (can be set to compute on each node as well)\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ySL-vDadvSV8",
        "colab_type": "text"
      },
      "source": [
        "Example with a NER metric: `seqeval`"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "f4uZym7MvSV9",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "ner_metric = nlp.load_metric('seqeval')\n",
        "references = [['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]\n",
        "predictions =  [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]\n",
        "ner_metric.compute(predictions, references)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ctY6AIAilLdH",
        "colab_type": "text"
      },
      "source": [
        "# Adding a new dataset or a new metric\n",
        "\n",
        "There are two ways to add new datasets and metrics in `nlp`:\n",
        "\n",
        "- datasets can be added with a Pull-Request adding a script in the `datasets` folder of the [`nlp` repository](https://github.com/huggingface/nlp)\n",
        "\n",
        "=> once the PR is merged, the dataset can be instantiated by its folder name e.g. `nlp.load_dataset('squad')`. If you want HuggingFace to host the data as well you will need to ask the HuggingFace team to upload the data.\n",
        "\n",
        "- datasets can also be added with a direct upload using `nlp` CLI as a user or organization (like for models in `transformers`). In this case the dataset will be accessible under the given user/organization name, e.g. `nlp.load_dataset('thomwolf/squad')`. In this case you can upload the data yourself at the same time and in the same folder.\n",
        "\n",
        "We will add a full tutorial on how to add and upload datasets soon."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ypLjbtGrljk8",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}